diff --git a/README.md b/README.md
index b1af641..3b1a688 100644
--- a/README.md
+++ b/README.md
@@ -179,12 +179,11 @@ The `Adapter` API lets you **easily patch models** by injecting parameters in ta
 E.g. to inject LoRA layers in all attention's linear layers:
 
 ```python
-from refiners.adapters.lora import LoraAdapter
+from refiners.adapters.lora import SingleLoraAdapter
 
 for layer in vit.layers(fl.Attention):
     for linear, parent in layer.walk(fl.Linear):
-        adapter = LoraAdapter(target=linear, rank=64)
-        adapter.inject(parent)
+        SingleLoraAdapter(target=linear, rank=64).inject(parent)
 
 # ... and load existing weights if the LoRAs are pretrained ...
 ```
@@ -232,7 +231,7 @@ Step 3: run inference using the GPU:
 
 ```python
 from refiners.foundationals.latent_diffusion import StableDiffusion_1
-from refiners.foundationals.latent_diffusion.lora import LoraWeights
+from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.fluxion.utils import load_from_safetensors, manual_seed
 import torch
 
@@ -242,9 +241,7 @@ sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors")
 sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
 sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
 
-# This uses the LoraAdapter internally and takes care to inject it where it should
-lora_weights = LoraWeights("pokemon_lora.safetensors", device=sd15.device)
-lora_weights.patch(sd15, scale=1.0)
+SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()
 
 prompt = "a cute cat"
 
diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py
index 3f24d0e..50935ed 100644
--- a/scripts/conversion/convert_diffusers_controlnet.py
+++ b/scripts/conversion/convert_diffusers_controlnet.py
@@ -8,7 +8,7 @@ from refiners.fluxion.utils import save_to_safetensors
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.foundationals.latent_diffusion import (
     SD1UNet,
-    SD1Controlnet,
+    SD1ControlnetAdapter,
     DPMSolver,
 )
 
@@ -21,13 +21,13 @@ class Args(argparse.Namespace):
 @torch.no_grad()
 def convert(args: Args) -> dict[str, torch.Tensor]:
     controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path)  # type: ignore
-    controlnet = SD1Controlnet(name="mycn")
+    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
+    adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
+    controlnet = unet.Controlnet
 
     condition = torch.randn(1, 3, 512, 512)
-    controlnet.set_controlnet_condition(condition=condition)
+    adapter.set_controlnet_condition(condition=condition)
 
-    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
-    unet.insert(index=0, module=controlnet)
     clip_text_embedding = torch.rand(1, 77, 768)
     unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
 
diff --git a/scripts/conversion/convert_diffusers_lora.py b/scripts/conversion/convert_diffusers_lora.py
index 2942393..f0052d5 100644
--- a/scripts/conversion/convert_diffusers_lora.py
+++ b/scripts/conversion/convert_diffusers_lora.py
@@ -1,17 +1,20 @@
 import argparse
 from pathlib import Path
 from typing import cast
+
 import torch
 from torch import Tensor
 from torch.nn.init import zeros_
 from torch.nn import Parameter as TorchParameter
+
 from diffusers import DiffusionPipeline  # type: ignore
+
 import refiners.fluxion.layers as fl
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import save_to_safetensors
+from refiners.adapters.lora import Lora, LoraAdapter
 from refiners.foundationals.latent_diffusion import SD1UNet
-from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
-from refiners.adapters.lora import Lora
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
 
 
 def get_weight(linear: fl.Linear) -> torch.Tensor:
@@ -69,7 +72,8 @@ def process(args: Args) -> None:
 
     diffusers_to_refiners = converter.get_mapping()
 
-    apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
+    LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
+
     for layer in refiners_model.layers(layer_type=Lora):
         zeros_(tensor=layer.Linear_1.weight)
 
@@ -85,7 +89,9 @@ def process(args: Args) -> None:
             p = p[seg]
         assert isinstance(p, fl.Chain)
         last_seg = (
-            "LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
+            "SingleLoraAdapter"
+            if orig_path[-1] == "Linear"
+            else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
         )
         p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
         p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
diff --git a/scripts/training/finetune-ldm-lora.py b/scripts/training/finetune-ldm-lora.py
index fb6e030..19c07b8 100644
--- a/scripts/training/finetune-ldm-lora.py
+++ b/scripts/training/finetune-ldm-lora.py
@@ -2,9 +2,8 @@ import random
 from typing import Any
 from pydantic import BaseModel
 from loguru import logger
-from refiners.adapters.lora import LoraAdapter, Lora
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -27,13 +26,6 @@ class LoraConfig(BaseModel):
     text_encoder_targets: list[LoraTarget]
     lda_targets: list[LoraTarget]
 
-    def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
-        for linear, parent in lora_targets(module, target):
-            adapter = LoraAdapter(target=linear, rank=self.rank)
-            adapter.inject(parent)
-            for linear in adapter.Lora.layers(fl.Linear):
-                linear.requires_grad_(requires_grad=True)
-
 
 class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
     def __init__(
@@ -84,45 +76,30 @@ class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfi
 class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
     def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
         lora_config = trainer.config.lora
-        for target in lora_config.unet_targets:
-            lora_config.apply_loras_to_target(module=trainer.unet, target=target)
-        for target in lora_config.text_encoder_targets:
-            lora_config.apply_loras_to_target(module=trainer.text_encoder, target=target)
-        for target in lora_config.lda_targets:
-            lora_config.apply_loras_to_target(module=trainer.lda, target=target)
+
+        for model_name in MODELS:
+            model = getattr(trainer, model_name)
+            adapter = LoraAdapter[type(model)](
+                model,
+                sub_targets=getattr(lora_config, f"{model_name}_targets"),
+                rank=lora_config.rank,
+            )
+            for sub_adapter, _ in adapter.sub_adapters:
+                for linear in sub_adapter.Lora.layers(fl.Linear):
+                    linear.requires_grad_(requires_grad=True)
+            adapter.inject()
 
 
 class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
     def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
-        lora_config = trainer.config.lora
-
-        def get_weight(linear: fl.Linear) -> Tensor:
-            assert linear.bias is None
-            return linear.state_dict()["weight"]
-
-        def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, Tensor]:
-            weights: list[Tensor] = []
-            for lora in module.layers(layer_type=Lora):
-                linears = list(lora.layers(fl.Linear))
-                assert len(linears) == 2
-                # See `load_lora_weights` in refiners.adapters.lora
-                weights.extend((get_weight(linears[1]), get_weight(linears[0])))  # aka (up_weight, down_weight)
-            return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
-
         tensors: dict[str, Tensor] = {}
         metadata: dict[str, str] = {}
 
-        if lora_config.unet_targets:
-            tensors |= build_loras_safetensors(trainer.unet, key_prefix="unet.")
-            metadata |= {"unet_targets": ",".join(lora_config.unet_targets)}
-
-        if lora_config.text_encoder_targets:
-            tensors |= build_loras_safetensors(trainer.text_encoder, key_prefix="text_encoder.")
-            metadata |= {"text_encoder_targets": ",".join(lora_config.text_encoder_targets)}
-
-        if lora_config.lda_targets:
-            tensors |= build_loras_safetensors(trainer.lda, key_prefix="lda.")
-            metadata |= {"lda_targets": ",".join(lora_config.lda_targets)}
+        for model_name in MODELS:
+            model = getattr(trainer, model_name)
+            adapter = model.parent
+            tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
+            metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}
 
         save_to_safetensors(
             path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
diff --git a/src/refiners/adapters/adapter.py b/src/refiners/adapters/adapter.py
index 9c76750..5a853aa 100644
--- a/src/refiners/adapters/adapter.py
+++ b/src/refiners/adapters/adapter.py
@@ -4,7 +4,7 @@ from typing import Any, Generic, TypeVar, Iterator
 
 T = TypeVar("T", bound=fl.Module)
-TAdapter = TypeVar("TAdapter", bound="Adapter[fl.Module]")
+TAdapter = TypeVar("TAdapter", bound="Adapter[Any]")  # Self (see PEP 673)
 
 
 class Adapter(Generic[T]):
@@ -36,27 +36,61 @@ class Adapter(Generic[T]):
             yield
             target._can_refresh_parent = _old_can_refresh_parent
 
-    def inject(self, parent: fl.Chain | None = None) -> None:
+    def lookup_actual_target(self) -> fl.Module:
+        # In general, the "actual target" is the target.
+        # This method deals with the edge case where the target
+        # is part of the replacement block and has been adapted by
+        # another adapter after this one. For instance, this is the
+        # case when stacking Controlnets.
         assert isinstance(self, fl.Chain)
+        target_parent = self.find_parent(self.target)
+        if (target_parent is None) or (target_parent == self):
+            return self.target
+
+        # Lookup and return last adapter in parents tree (or target if none).
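+        # E.g. with two stacked Controlnet adapters sharing one UNet, the first
+        # adapter's target ends up wrapped by the second, so ejecting the first
+        # must yield that wrapper rather than the bare UNet (see
+        # tests/foundationals/latent_diffusion/test_controlnet.py).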
+        r, p = self.target, target_parent
+        while p != self:
+            if isinstance(p, Adapter):
+                r = p
+            assert p.parent, f"parent tree of {self} is broken"
+            p = p.parent
+        return r
+
+    def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
+        assert isinstance(self, fl.Chain)
+
+        if (parent is None) and isinstance(self.target, fl.ContextModule):
+            parent = self.target.parent
+            if parent is not None:
+                assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
+
+        target_parent = self.find_parent(self.target)
+
+        if parent is None:
             if isinstance(self.target, fl.ContextModule):
-                parent = self.target.parent
-            else:
-                raise ValueError(f"parent of {self.target} is mandatory")
-        assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
+                self.target._set_parent(target_parent)  # type: ignore[reportPrivateUsage]
+            return self
+
         if self.target not in iter(parent):
             raise ValueError(f"{self.target} is not in {parent}")
 
         parent.replace(
             old_module=self.target,
             new_module=self,
-            old_module_parent=self.find_parent(self.target),
+            old_module_parent=target_parent,
         )
+        return self
 
     def eject(self) -> None:
         assert isinstance(self, fl.Chain)
-        self.ensure_parent.replace(old_module=self, new_module=self.target)
+        actual_target = self.lookup_actual_target()
+
+        if (parent := self.parent) is None:
+            if isinstance(actual_target, fl.ContextModule):
+                actual_target._set_parent(None)  # type: ignore[reportPrivateUsage]
+        else:
+            parent.replace(old_module=self, new_module=actual_target)
 
     def _pre_structural_copy(self) -> None:
         if isinstance(self.target, fl.Chain):
diff --git a/src/refiners/adapters/lora.py b/src/refiners/adapters/lora.py
index d7f167d..b157fed 100644
--- a/src/refiners/adapters/lora.py
+++ b/src/refiners/adapters/lora.py
@@ -1,8 +1,14 @@
+from typing import Iterable, Generic, TypeVar, Any
+
 import refiners.fluxion.layers as fl
 from refiners.adapters.adapter import Adapter
-from torch.nn.init import zeros_, normal_
 from torch import Tensor, device as Device, dtype as DType
+from torch.nn import Parameter as TorchParameter
+from torch.nn.init import zeros_, normal_
+
+T = TypeVar("T", bound=fl.Chain)
+TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]")  # Self (see PEP 673)
 
 
 class Lora(fl.Chain):
@@ -37,11 +43,19 @@ class Lora(fl.Chain):
         self.scale = scale
 
     def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
-        self.Linear_1.weight = down_weight
-        self.Linear_2.weight = up_weight
+        self.Linear_1.weight = TorchParameter(down_weight)
+        self.Linear_2.weight = TorchParameter(up_weight)
+
+    @property
+    def up_weight(self) -> Tensor:
+        return self.Linear_2.weight.data
+
+    @property
+    def down_weight(self) -> Tensor:
+        return self.Linear_1.weight.data
 
 
-class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
+class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
     structural_attrs = ["in_features", "out_features", "rank", "scale"]
 
     def __init__(
@@ -67,20 +81,54 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
         )
         self.Lora.set_scale(scale=scale)
 
-    def add_lora(self, lora: Lora) -> None:
-        self.append(module=lora)
-
-    def load_lora_weights(self, up_weight: Tensor, down_weight: Tensor, index: int = 0) -> None:
-        self[index + 1].load_weights(up_weight=up_weight, down_weight=down_weight)
 
+class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
+    def __init__(
+        self,
+        target: T,
+        sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
+        rank: int | None = None,
+        scale: float = 1.0,
+        weights: list[Tensor] | None = None,
+    ) -> None:
+        with self.setup_adapter(target):
+            super().__init__(target)
+        if weights is not None:
+            assert len(weights) % 2 == 0
+            weights_rank = weights[0].shape[1]
+            if rank is None:
+                rank = weights_rank
+            else:
+                assert rank == weights_rank
 
-def load_lora_weights(model: fl.Chain, weights: list[Tensor]) -> None:
-    assert len(weights) % 2 == 0, "Number of weights must be even"
-    assert (
-        len(list(model.layers(layer_type=Lora))) == len(weights) // 2
-    ), "Number of Lora layers must match number of weights"
-    for i, lora in enumerate(iterable=model.layers(layer_type=Lora)):
-        assert (
-            lora.rank == weights[i * 2].shape[1]
-        ), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
-        lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
+        assert rank is not None, "either pass a rank or weights"
+
+        self.sub_targets = sub_targets
+        self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
+
+        for linear, parent in self.sub_targets:
+            self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
+
+        if weights is not None:
+            assert len(self.sub_adapters) == (len(weights) // 2)
+            for i, (adapter, _) in enumerate(self.sub_adapters):
+                lora = adapter.Lora
+                assert (
+                    lora.rank == weights[i * 2].shape[1]
+                ), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
+                adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
+
+    def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
+        for adapter, adapter_parent in self.sub_adapters:
+            adapter.inject(adapter_parent)
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        for adapter, _ in self.sub_adapters:
+            adapter.eject()
+        super().eject()
+
+    @property
+    def weights(self) -> list[Tensor]:
+        return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
diff --git a/src/refiners/foundationals/clip/concepts.py b/src/refiners/foundationals/clip/concepts.py
index 586f5eb..c178c7d 100644
--- a/src/refiners/foundationals/clip/concepts.py
+++ b/src/refiners/foundationals/clip/concepts.py
@@ -10,7 +10,7 @@ from torch.nn import Parameter
 import re
 
 
-class ConceptExtender:
+class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
     """
     Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
@@ -37,6 +37,9 @@
     """
 
     def __init__(self, target: CLIPTextEncoder) -> None:
+        with self.setup_adapter(target):
+            super().__init__(target)
+
         try:
             token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
         except StopIteration:
@@ -54,13 +57,15 @@
         self.embedding_extender.add_embedding(embedding)
         self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
 
-    def inject(self) -> None:
+    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
         self.embedding_extender.inject(self.token_encoder_parent)
         self.token_extender.inject(self.clip_tokenizer_parent)
+        return super().inject(parent)
 
     def eject(self) -> None:
         self.embedding_extender.eject()
         self.token_extender.eject()
+        super().eject()
 
 
 class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py
index cd2d6d1..8caf5de 100644
--- a/src/refiners/foundationals/latent_diffusion/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/__init__.py
@@ -9,7 +9,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     StableDiffusion_1,
     StableDiffusion_1_Inpainting,
     SD1UNet,
-    SD1Controlnet,
+    SD1ControlnetAdapter,
 )
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
     SDXLUNet,
@@ -21,7 +21,7 @@ __all__ = [
     "StableDiffusion_1",
     "StableDiffusion_1_Inpainting",
     "SD1UNet",
-    "SD1Controlnet",
+    "SD1ControlnetAdapter",
     "SDXLUNet",
     "DoubleTextEncoder",
     "DPMSolver",
diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py
index 929cf49..910da40 100644
--- a/src/refiners/foundationals/latent_diffusion/lora.py
+++ b/src/refiners/foundationals/latent_diffusion/lora.py
@@ -2,17 +2,26 @@ from enum import Enum
 from pathlib import Path
 from typing import Iterator
 
+from torch import Tensor
-from torch import Tensor, device as Device
-from torch.nn import Parameter as TorchParameter
-from refiners.adapters.lora import LoraAdapter, load_lora_weights
-from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
-from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
-from refiners.foundationals.latent_diffusion import StableDiffusion_1
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
 import refiners.fluxion.layers as fl
 from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
+from refiners.adapters.adapter import Adapter
+from refiners.adapters.lora import SingleLoraAdapter, LoraAdapter
+
+from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
+from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
+from refiners.foundationals.latent_diffusion import (
+    StableDiffusion_1,
+    SD1UNet,
+    CLIPTextEncoderL,
+    LatentDiffusionAutoencoder,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
+
+MODELS = ["unet", "text_encoder", "lda"]
+
 
 class LoraTarget(str, Enum):
     Self = "self"
@@ -38,66 +47,95 @@ class LoraTarget(str, Enum):
         return TransformerLayer
 
 
-def get_lora_rank(weights: list[Tensor]) -> int:
-    ranks: set[int] = {w.shape[1] for w in weights[0::2]}
-    assert len(ranks) == 1
-    return ranks.pop()
-
-
 def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
-    it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
-    for layer in it:
+    lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
+
+    if isinstance(module, SD1UNet):
+
+        def predicate(m: fl.Module, p: fl.Chain) -> bool:
+            if isinstance(m, Controlnet):  # do not adapt Controlnet linears
+                raise StopIteration
+            return isinstance(m, lookup_class)
+
+    else:
+
+        def predicate(m: fl.Module, p: fl.Chain) -> bool:
+            return isinstance(m, lookup_class)
+
+    if target == LoraTarget.Self:
+        for m, p in module.walk(predicate):
+            assert isinstance(m, fl.Linear)
+            yield (m, p)
+        return
+
+    for layer, _ in module.walk(predicate):
         for t in layer.walk(fl.Linear):
             yield t
 
 
-def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
-    for linear, parent in lora_targets(module, target):
-        adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
-        adapter.inject(parent)
-
-
-class LoraWeights:
-    """A single LoRA weights training checkpoint used to patch a Stable Diffusion 1.5 model."""
-
+class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
     metadata: dict[str, str] | None
     tensors: dict[str, Tensor]
 
-    def __init__(self, checkpoint_path: Path | str, device: Device | str):
-        self.metadata = load_metadata_from_safetensors(checkpoint_path)
-        self.tensors = load_from_safetensors(checkpoint_path, device=device)
+    def __init__(
+        self,
+        target: StableDiffusion_1,
+        sub_targets: dict[str, list[LoraTarget]],
+        scale: float = 1.0,
+        weights: dict[str, Tensor] | None = None,
+    ):
+        with self.setup_adapter(target):
+            super().__init__(target)
 
-    def patch(self, sd: StableDiffusion_1, scale: float = 1.0) -> None:
-        assert self.metadata is not None, "Invalid safetensors checkpoint: missing metadata"
+        self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
 
-        for meta_key, meta_value in self.metadata.items():
-            match meta_key:
-                case "unet_targets":
-                    # TODO: support this transparently
-                    if any([isinstance(module, SD1Controlnet) for module in sd.unet]):
-                        raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
-                    model = sd.unet
-                    key_prefix = "unet."
-                case "text_encoder_targets":
-                    model = sd.clip_text_encoder
-                    key_prefix = "text_encoder."
-                case "lda_targets":
-                    model = sd.lda
-                    key_prefix = "lda."
-                case _:
-                    raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
+        for model_name in MODELS:
+            if not (model_targets := sub_targets.get(model_name, [])):
+                continue
+            model = getattr(target, model_name)
+            if model.find(SingleLoraAdapter):
+                raise NotImplementedError(f"{model} already contains LoRA layers")
 
-            # TODO(FG-487): support loading multiple LoRA-s
-            if any(model.layers(LoraAdapter)):
-                raise NotImplementedError(f"{model.__class__.__name__} already contains LoRA layers")
+            lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
+            self.sub_adapters.append(
+                LoraAdapter[type(model)](
+                    model,
+                    sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
+                    scale=scale,
+                    weights=lora_weights,
+                )
+            )
 
-            lora_weights = [w for w in [self.tensors[k] for k in sorted(self.tensors) if k.startswith(key_prefix)]]
-            assert len(lora_weights) % 2 == 0
+    @classmethod
+    def from_safetensors(
+        cls,
+        target: StableDiffusion_1,
+        checkpoint_path: Path | str,
+        scale: float = 1.0,
+    ):
+        metadata = load_metadata_from_safetensors(checkpoint_path)
+        assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
+        tensors = load_from_safetensors(checkpoint_path, device=target.device)
 
-        rank = get_lora_rank(lora_weights)
-        for target in meta_value.split(","):
-            apply_loras_to_target(model, target=LoraTarget(target), rank=rank, scale=scale)
+        sub_targets = {}
+        for model_name in MODELS:
+            if not (v := metadata.get(f"{model_name}_targets", "")):
+                continue
+            sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
 
-        assert len(list(model.layers(LoraAdapter))) == (len(lora_weights) // 2)
+        return cls(
+            target,
+            sub_targets,
+            scale=scale,
+            weights=tensors,
+        )
 
-        load_lora_weights(model, [TorchParameter(w) for w in lora_weights])
+    def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
+        for adapter in self.sub_adapters:
+            adapter.inject()
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        for adapter in self.sub_adapters:
+            adapter.eject()
+        super().eject()
diff --git a/src/refiners/foundationals/latent_diffusion/self_attention_injection.py b/src/refiners/foundationals/latent_diffusion/self_attention_injection.py
index 2f02b5b..ccc6fa7 100644
--- a/src/refiners/foundationals/latent_diffusion/self_attention_injection.py
+++ b/src/refiners/foundationals/latent_diffusion/self_attention_injection.py
@@ -27,10 +27,10 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
         self,
         target: SelfAttention,
         context: str,
-        sai: "SelfAttentionInjection",
+        style_cfg: float = 0.5,
     ) -> None:
         self.context = context
-        self._sai = [sai]  # only to support setting `style_cfg` dynamically
+        self.style_cfg = style_cfg
 
         sa_guided = target.structural_copy()
         assert isinstance(sa_guided[0], Parallel)
@@ -50,62 +50,26 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
         )
 
     def compute_averaged_unconditioned_x(self, x: Tensor, unguided_unconditioned_x: Tensor) -> Tensor:
-        style_cfg = self._sai[0].style_cfg
-        x[0] = style_cfg * x[0] + (1.0 - style_cfg) * unguided_unconditioned_x
+        x[0] = self.style_cfg * x[0] + (1.0 - self.style_cfg) * unguided_unconditioned_x
         return x
 
 
-class SelfAttentionInjection(Passthrough):
-    # TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
-
-    def __init__(self, unet: SD1UNet, style_cfg: float = 0.5) -> None:
-        # the style_cfg is the weight of the guide in unconditionned diffusion.
-        # This value is recommended to be 0.5 on the sdwebui repo.
-        self.style_cfg = style_cfg
-        self._adapters: list[ReferenceOnlyControlAdapter] = []
-        self._unet = [unet]
-
-        guide_unet = unet.structural_copy()
+class SelfAttentionInjectionPassthrough(Passthrough):
+    def __init__(self, target: SD1UNet) -> None:
+        guide_unet = target.structural_copy()
         for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
             sa = attention_block.find(SelfAttention)
             assert sa is not None and sa.parent is not None
             SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
 
-        for i, attention_block in enumerate(unet.layers(CrossAttentionBlock)):
-            unet.set_context(f"self_attention_context_{i}", {"norm": None})
-
-            sa = attention_block.find(SelfAttention)
-            assert sa is not None and sa.parent is not None
-
-            self._adapters.append(ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", sai=self))
-
         super().__init__(
-            Lambda(self.copy_diffusion_context),
+            Lambda(self._copy_diffusion_context),
             UseContext("self_attention_injection", "guide"),
             guide_unet,
-            Lambda(self.restore_diffusion_context),
+            Lambda(self._restore_diffusion_context),
         )
 
-    @property
-    def unet(self):
-        return self._unet[0]
-
-    def inject(self) -> None:
-        assert self not in self._unet[0], f"{self} is already injected"
-        for adapter in self._adapters:
-            adapter.inject()
-        self.unet.insert(0, self)
-
-    def eject(self) -> None:
-        assert self.unet[0] == self, f"{self} is not the first element of target UNet"
-        for adapter in self._adapters:
-            adapter.eject()
-        self.unet.pop(0)
-
-    def set_controlnet_condition(self, condition: Tensor) -> None:
-        self.set_context("self_attention_injection", {"guide": condition})
-
-    def copy_diffusion_context(self, x: Tensor) -> Tensor:
+    def _copy_diffusion_context(self, x: Tensor) -> Tensor:
         # This function allows to not disrupt the accumulation of residuals in the unet (if controlnet are used)
         self.set_context(
             "self_attention_residuals_buffer",
@@ -117,7 +81,7 @@ class SelfAttentionInjection(Passthrough):
         )
         return x
 
-    def restore_diffusion_context(self, x: Tensor) -> Tensor:
+    def _restore_diffusion_context(self, x: Tensor) -> Tensor:
         self.set_context(
             "unet",
             {
@@ -126,5 +90,50 @@ class SelfAttentionInjection(Passthrough):
         )
         return x
 
+
+class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
+    # TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
+
+    def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
+        # style_cfg is the weight of the guide in unconditioned diffusion.
+        # This value is recommended to be 0.5 on the sdwebui repo.
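+        # Per compute_averaged_unconditioned_x above: style_cfg=1.0 keeps only
+        # the guided unconditioned prediction, 0.0 restores the unguided one.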
+
+        self.sub_adapters: list[ReferenceOnlyControlAdapter] = []
+        self._passthrough: list[SelfAttentionInjectionPassthrough] = [
+            SelfAttentionInjectionPassthrough(target)
+        ]  # not registered by PyTorch
+
+        with self.setup_adapter(target):
+            super().__init__(target)
+
+        for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
+            self.set_context(f"self_attention_context_{i}", {"norm": None})
+
+            sa = attention_block.find(SelfAttention)
+            assert sa is not None and sa.parent is not None
+
+            self.sub_adapters.append(
+                ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
+            )
+
+    def inject(self: "SelfAttentionInjection", parent: Chain | None = None) -> "SelfAttentionInjection":
+        passthrough = self._passthrough[0]
+        assert passthrough not in self.target, f"{passthrough} is already injected"
+        for adapter in self.sub_adapters:
+            adapter.inject()
+        self.target.insert(0, passthrough)
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        passthrough = self._passthrough[0]
+        assert self.target[0] == passthrough, f"{passthrough} is not the first element of target UNet"
+        for adapter in self.sub_adapters:
+            adapter.eject()
+        self.target.pop(0)
+        super().eject()
+
+    def set_controlnet_condition(self, condition: Tensor) -> None:
+        self.set_context("self_attention_injection", {"guide": condition})
+
     def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
         raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
index bf93fe0..65e7e72 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
@@ -3,11 +3,11 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
     StableDiffusion_1,
     StableDiffusion_1_Inpainting,
 )
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
 
 __all__ = [
     "StableDiffusion_1",
     "StableDiffusion_1_Inpainting",
     "SD1UNet",
-    "SD1Controlnet",
+    "SD1ControlnetAdapter",
 ]
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
index 6084047..029e584 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py
@@ -1,11 +1,13 @@
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
+    SD1UNet,
     DownBlocks,
     MiddleBlock,
     ResidualBlock,
     TimestepEncoder,
 )
+from refiners.adapters.adapter import Adapter
 from refiners.adapters.range_adapter import RangeAdapter2d
 from typing import cast, Iterable
 from torch import Tensor, device as Device, dtype as DType
@@ -69,10 +71,12 @@ class ConditionEncoder(Chain):
         )
 
 
-class SD1Controlnet(Passthrough):
-    structural_attrs = ["name", "scale"]
+class Controlnet(Passthrough):
+    structural_attrs = ["scale"]
 
-    def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
+    def __init__(
+        self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
+    ) -> None:
         """Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
 
         Input is a `batch 3 width height` tensor, output is a `batch 1280 width//8 height//8` tensor with residuals
@@ -80,8 +84,7 @@ class SD1Controlnet(Passthrough):
 
         It has to use the same context as the UNet: `unet` and `sampling`.
         """
-        self.name = name
-        self.scale: float = 1.0
+        self.scale = scale
         super().__init__(
             TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
             Lambda(lambda x: x.narrow(dim=1, start=0, length=4)),  # support inpainting
@@ -102,15 +105,14 @@ class SD1Controlnet(Passthrough):
         )
         for residual_block in self.layers(ResidualBlock):
             chain = residual_block.Chain
-            range_adapter = RangeAdapter2d(
+            RangeAdapter2d(
                 target=chain.Conv2d_1,
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
-                context_key=f"timestep_embedding_{self.name}",
+                context_key=f"timestep_embedding_{name}",
                 device=device,
                 dtype=dtype,
-            )
-            range_adapter.inject(chain)
+            ).inject(chain)
         for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
             assert hasattr(block[0], "out_channels"), (
                 "The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
@@ -132,14 +134,6 @@ class SD1Controlnet(Passthrough):
             )
         )
 
-    def init_context(self) -> Contexts:
-        return {
-            "unet": {"residuals": [0.0] * 13},
-            "sampling": {"shapes": []},
-            "controlnet": {f"condition_{self.name}": None},
-            "range_adapter": {f"timestep_embedding_{self.name}": None},
-        }
-
     def _store_nth_residual(self, n: int):
         def _store_residual(x: Tensor):
             residuals = self.use_context("unet")["residuals"]
@@ -148,8 +142,39 @@ class SD1Controlnet(Passthrough):
 
         return _store_residual
 
+
+class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
+    def __init__(
+        self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
+    ) -> None:
+        self.name = name
+
+        controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
+        if weights is not None:
+            controlnet.load_state_dict(weights)
+        self._controlnet: list[Controlnet] = [controlnet]  # not registered by PyTorch
+
+        with self.setup_adapter(target):
+            super().__init__(target)
+
+    def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
+        controlnet = self._controlnet[0]
+        assert controlnet not in self.target, f"{controlnet} is already injected"
+        self.target.insert(0, controlnet)
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        self.target.remove(self._controlnet[0])
+        super().eject()
+
+    def init_context(self) -> Contexts:
+        return {"controlnet": {f"condition_{self.name}": None}}
+
+    def set_scale(self, scale: float) -> None:
+        self._controlnet[0].scale = scale
+
     def set_controlnet_condition(self, condition: Tensor) -> None:
         self.set_context("controlnet", {f"condition_{self.name}": condition})
 
-    def set_scale(self, scale: float) -> None:
-        self.scale = scale
+    def structural_copy(self: "SD1ControlnetAdapter") -> "SD1ControlnetAdapter":
+        raise RuntimeError("Controlnet cannot be copied, eject it first.")
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
index 314fd86..8d4e15e 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py
@@ -278,15 +278,14 @@ class SD1UNet(fl.Chain):
         )
         for residual_block in self.layers(ResidualBlock):
             chain = residual_block.Chain
-            range_adapter = RangeAdapter2d(
+            RangeAdapter2d(
                 target=chain.Conv2d_1,
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
                 context_key="timestep_embedding",
                 device=device,
                 dtype=dtype,
-            )
-            range_adapter.inject(chain)
+            ).inject(chain)
         for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
             block.append(ResidualAccumulator(n))
         for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
index 4123b11..5d159f4 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/text_encoder.py
@@ -70,12 +70,11 @@ class DoubleTextEncoder(fl.Chain):
     ) -> None:
         text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
         text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
-        text_encoder_with_pooling = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
         super().__init__(
             fl.Parallel(text_encoder_l[:-2], text_encoder_g),
             fl.Lambda(func=self.concatenate_embeddings),
         )
-        text_encoder_with_pooling.inject(parent=self.Parallel)
+        TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)
 
     def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
         return super().__call__(text)
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
index 40b0819..192e5b3 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py
@@ -261,15 +261,14 @@ class SDXLUNet(fl.Chain):
         )
         for residual_block in self.layers(ResidualBlock):
             chain = residual_block.Chain
-            range_adapter = RangeAdapter2d(
+            RangeAdapter2d(
                 target=chain.Conv2d_1,
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
                 context_key="timestep_embedding",
                 device=device,
                 dtype=dtype,
-            )
-            range_adapter.inject(chain)
+            ).inject(chain)
         for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
             block.append(module=ResidualAccumulator(n=n))
         for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):
diff --git a/src/refiners/training_utils/dropout.py b/src/refiners/training_utils/dropout.py
index 45ac588..7b1a12c 100644
--- a/src/refiners/training_utils/dropout.py
+++ b/src/refiners/training_utils/dropout.py
@@ -164,8 +164,7 @@ def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
             assert not (
                 isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
             ), f"{linear} already has a dropout layer"
-            adapter = DropoutAdapter(target=linear, probability=probability)
-            adapter.inject(parent)
+            DropoutAdapter(target=linear, probability=probability).inject(parent)
 
 
 def apply_gyro_dropout(
@@ -181,14 +180,13 @@ def apply_gyro_dropout(
             assert not (
                 isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
             ), f"{linear} already has a dropout layer"
-            adapter = GyroDropoutAdapter(
+            GyroDropoutAdapter(
                 target=linear,
                 probability=probability,
                 total_subnetworks=total_subnetworks,
                 concurrent_subnetworks=concurrent_subnetworks,
concurrent_subnetworks=concurrent_subnetworks, iters_per_epoch=iters_per_epoch, - ) - adapter.inject(parent) + ).inject(parent) ConfigType = TypeVar("ConfigType", bound="BaseConfig") diff --git a/tests/adapters/test_adapter.py b/tests/adapters/test_adapter.py index 7a3426c..8f4ee52 100644 --- a/tests/adapters/test_adapter.py +++ b/tests/adapters/test_adapter.py @@ -24,9 +24,8 @@ def test_weighted_module_adapter_insertion(chain: Chain): parent = chain.Chain adaptee = parent.Linear - adapter = DummyLinearAdapter(adaptee) + adapter = DummyLinearAdapter(adaptee).inject(parent) - adapter.inject(parent) assert adapter.parent == parent assert adapter in iter(parent) assert adaptee not in iter(parent) @@ -61,8 +60,7 @@ def test_weighted_module_adapter_structural_copy(chain: Chain): parent = chain.Chain adaptee = parent.Linear - adapter = DummyLinearAdapter(adaptee) - adapter.inject(parent) + DummyLinearAdapter(adaptee).inject(parent) clone = chain.structural_copy() cloned_adapter = clone.Chain.DummyLinearAdapter @@ -72,8 +70,7 @@ def test_weighted_module_adapter_structural_copy(chain: Chain): def test_chain_adapter_structural_copy(chain: Chain): # Chain adapters cannot be copied by default. - adapter = DummyChainAdapter(chain.Chain) - adapter.inject() + adapter = DummyChainAdapter(chain.Chain).inject() with pytest.raises(RuntimeError): chain.structural_copy() diff --git a/tests/adapters/test_lora.py b/tests/adapters/test_lora.py index 148d506..6255b88 100644 --- a/tests/adapters/test_lora.py +++ b/tests/adapters/test_lora.py @@ -1,9 +1,9 @@ -from refiners.adapters.lora import Lora, LoraAdapter +from refiners.adapters.lora import Lora, SingleLoraAdapter, LoraAdapter from torch import randn, allclose import refiners.fluxion.layers as fl -def test_lora() -> None: +def test_single_lora_adapter() -> None: chain = fl.Chain( fl.Chain( fl.Linear(in_features=1, out_features=1), @@ -14,8 +14,7 @@ def test_lora() -> None: x = randn(1, 1) y = chain(x) - lora_adapter = LoraAdapter(chain.Chain.Linear_1) - lora_adapter.inject(chain.Chain) + lora_adapter = SingleLoraAdapter(chain.Chain.Linear_1).inject(chain.Chain) assert isinstance(lora_adapter[1], Lora) assert allclose(input=chain(x), other=y) @@ -26,4 +25,18 @@ def test_lora() -> None: assert len(chain) == 2 lora_adapter.inject(chain.Chain) - assert isinstance(chain.Chain[0], LoraAdapter) + assert isinstance(chain.Chain[0], SingleLoraAdapter) + + +def test_lora_adapter() -> None: + chain = fl.Chain( + fl.Chain( + fl.Linear(in_features=1, out_features=1), + fl.Linear(in_features=1, out_features=1), + ), + fl.Linear(in_features=1, out_features=2), + ) + + LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject() + + assert len(list(chain.layers(Lora))) == 3 diff --git a/tests/adapters/test_range_adapter.py b/tests/adapters/test_range_adapter.py index 870c8a3..030ac84 100644 --- a/tests/adapters/test_range_adapter.py +++ b/tests/adapters/test_range_adapter.py @@ -15,8 +15,7 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG- chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype)) adaptee = chain.RangeEncoder.Linear_1 - adapter = DummyLinearAdapter(adaptee) - adapter.inject(chain.RangeEncoder) + adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder) assert adapter.parent == chain.RangeEncoder diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index efd9649..65ff01a 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -12,9 +12,9 
@@ from refiners.foundationals.latent_diffusion import ( StableDiffusion_1, StableDiffusion_1_Inpainting, SD1UNet, - SD1Controlnet, + SD1ControlnetAdapter, ) -from refiners.foundationals.latent_diffusion.lora import LoraWeights +from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter from refiners.foundationals.latent_diffusion.schedulers import DDIM from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection from refiners.foundationals.clip.concepts import ConceptExtender @@ -57,6 +57,11 @@ def expected_image_std_inpainting(ref_path: Path) -> Image.Image: return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB") +@pytest.fixture +def expected_image_controlnet_stack(ref_path: Path) -> Image.Image: + return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB") + + @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"]) def controlnet_data( ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest @@ -85,6 +90,15 @@ def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str, return cn_name, condition_image, expected_image, weights_path +@pytest.fixture(scope="module") +def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]: + cn_name = "depth" + condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB") + expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB") + weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors" + return cn_name, condition_image, expected_image, weights_path + + @pytest.fixture(scope="module") def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]: expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB") @@ -453,11 +467,9 @@ def test_diffusion_controlnet( sd15.set_num_inference_steps(n_steps) - controlnet_state_dict = load_from_safetensors(cn_weights_path) - controlnet = SD1Controlnet(name=cn_name, device=test_device) - controlnet.load_state_dict(controlnet_state_dict) - controlnet.set_scale(0.5) - sd15.unet.insert(0, controlnet) + controlnet = SD1ControlnetAdapter( + sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path) + ).inject() cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device) @@ -502,11 +514,9 @@ def test_diffusion_controlnet_structural_copy( sd15.set_num_inference_steps(n_steps) - controlnet_state_dict = load_from_safetensors(cn_weights_path) - controlnet = SD1Controlnet(name=cn_name, device=test_device) - controlnet.load_state_dict(controlnet_state_dict) - controlnet.set_scale(0.5) - sd15.unet.insert(0, controlnet) + controlnet = SD1ControlnetAdapter( + sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path) + ).inject() cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device) @@ -550,11 +560,9 @@ def test_diffusion_controlnet_float16( sd15.set_num_inference_steps(n_steps) - controlnet_state_dict = load_from_safetensors(cn_weights_path) - controlnet = SD1Controlnet(name=cn_name, device=test_device, dtype=torch.float16) - controlnet.load_state_dict(controlnet_state_dict) - controlnet.set_scale(0.5) - sd15.unet.insert(0, controlnet) + controlnet = SD1ControlnetAdapter( + sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path) + ).inject() 
     cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device, dtype=torch.float16)
 
@@ -575,6 +583,64 @@ def test_diffusion_controlnet_float16(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
+@torch.no_grad()
+def test_diffusion_controlnet_stack(
+    sd15_std: StableDiffusion_1,
+    controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
+    controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
+    expected_image_controlnet_stack: Image.Image,
+    test_device: torch.device,
+):
+    sd15 = sd15_std
+    n_steps = 30
+
+    _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
+    _, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
+
+    if not canny_cn_weights_path.is_file():
+        warn(f"could not find weights at {canny_cn_weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+
+    if not depth_cn_weights_path.is_file():
+        warn(f"could not find weights at {depth_cn_weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+
+    with torch.no_grad():
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    sd15.set_num_inference_steps(n_steps)
+
+    depth_controlnet = SD1ControlnetAdapter(
+        sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
+    ).inject()
+    canny_controlnet = SD1ControlnetAdapter(
+        sd15.unet, name="canny", scale=0.7, weights=load_from_safetensors(canny_cn_weights_path)
+    ).inject()
+
+    depth_cn_condition = image_to_tensor(depth_condition_image.convert("RGB"), device=test_device)
+    canny_cn_condition = image_to_tensor(canny_condition_image.convert("RGB"), device=test_device)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device)
+
+    with torch.no_grad():
+        for step in sd15.steps:
+            depth_controlnet.set_controlnet_condition(depth_cn_condition)
+            canny_controlnet.set_controlnet_condition(canny_cn_condition)
+            x = sd15(
+                x,
+                step=step,
+                clip_text_embedding=clip_text_embedding,
+                condition_scale=7.5,
+            )
+        predicted_image = sd15.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
+
+
 @torch.no_grad()
 def test_diffusion_lora(
     sd15_std: StableDiffusion_1,
@@ -597,8 +663,7 @@ def test_diffusion_lora(
 
     sd15.set_num_inference_steps(n_steps)
 
-    lora_weights = LoraWeights(lora_weights_path, device=test_device)
-    lora_weights.patch(sd15, scale=1.0)
+    SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -629,8 +694,7 @@ def test_diffusion_refonly(
     with torch.no_grad():
         clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
-    sai = SelfAttentionInjection(sd15.unet)
-    sai.inject()
+    sai = SelfAttentionInjection(sd15.unet).inject()
 
     guide = sd15.lda.encode_image(condition_image_refonly)
     guide = torch.cat((guide, guide))
@@ -671,29 +735,26 @@ def test_diffusion_inpainting_refonly(
     with torch.no_grad():
         clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
-    sai = SelfAttentionInjection(sd15.unet)
-    sai.inject()
+    sai = SelfAttentionInjection(sd15.unet).inject()
 
     sd15.set_num_inference_steps(n_steps)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
 
-    refonly_guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
-    refonly_guide = torch.cat((refonly_guide, refonly_guide))
+    guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
+    guide = torch.cat((guide, guide))
 
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
 
    with torch.no_grad():
         for step in sd15.steps:
-            refonly_noise = torch.randn_like(refonly_guide)
-            refonly_noised_guide = sd15.scheduler.add_noise(refonly_guide, refonly_noise, step)
+            noise = torch.randn_like(guide)
+            noised_guide = sd15.scheduler.add_noise(guide, noise, step)
             # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
             # inpaint variation models")
-            refonly_noised_guide = torch.cat(
-                [refonly_noised_guide, torch.zeros_like(refonly_noised_guide)[:, 0:1, :, :], refonly_guide], dim=1
-            )
+            noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
 
-            sai.set_controlnet_condition(refonly_noised_guide)
+            sai.set_controlnet_condition(noised_guide)
             x = sd15(
                 x,
                 step=step,
diff --git a/tests/e2e/test_diffusion_ref/expected_controlnet_stack.png b/tests/e2e/test_diffusion_ref/expected_controlnet_stack.png
new file mode 100644
index 0000000..6f48628
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_controlnet_stack.png differ
diff --git a/tests/foundationals/latent_diffusion/test_controlnet.py b/tests/foundationals/latent_diffusion/test_controlnet.py
new file mode 100644
index 0000000..a9de8dc
--- /dev/null
+++ b/tests/foundationals/latent_diffusion/test_controlnet.py
@@ -0,0 +1,85 @@
+from typing import Iterator
+
+import torch
+import pytest
+
+import refiners.fluxion.layers as fl
+from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
+
+
+@pytest.fixture(scope="module", params=[True, False])
+def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
+    with_parent: bool = request.param
+    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    if with_parent:
+        fl.Chain(unet)
+    yield unet
+
+
+@torch.no_grad()
+def test_single_controlnet(unet: SD1UNet) -> None:
+    original_parent = unet.parent
+    cn = SD1ControlnetAdapter(unet, name="cn")
+
+    assert unet.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 0
+
+    with pytest.raises(ValueError) as exc:
+        cn.eject()
+    assert "not in" in str(exc.value)
+
+    cn.inject()
+    assert unet.parent == cn
+    assert len(list(unet.walk(Controlnet))) == 1
+
+    with pytest.raises(AssertionError) as exc:
+        cn.inject()
+    assert "already injected" in str(exc.value)
+
+    cn.eject()
+    assert unet.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 0
+
+
+@torch.no_grad()
+def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
+    original_parent = unet.parent
+    cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
+    cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
+
+    assert unet.parent == cn2
+    assert unet in cn2
+    assert unet not in cn1
+    assert cn2.parent == cn1
+    assert cn2 in cn1
+    assert cn1.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 2
+    assert cn1.target == unet
+    assert cn1.lookup_actual_target() == cn2
+
+    cn2.eject()
+    assert unet.parent == cn1
+    assert unet in cn2
+    assert cn2 not in cn1
+    assert unet in cn1
+    assert len(list(unet.walk(Controlnet))) == 1
+
+    cn1.eject()
+    assert unet.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 0
+
+
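+# Ejecting cn1 first, while cn2 is still injected, exercises lookup_actual_target:
+# cn1 must splice its target's current wrapper (cn2) back into the parent chain,
+# not the bare UNet.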
+@torch.no_grad()
+def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
+    original_parent = unet.parent
+    cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
+    cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
+
+    cn1.eject()
+    assert cn2.parent == original_parent
+    assert unet.parent == cn2
+
+    cn2.eject()
+    assert unet.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 0
diff --git a/tests/foundationals/latent_diffusion/test_lora.py b/tests/foundationals/latent_diffusion/test_lora.py
deleted file mode 100644
index 4df6140..0000000
--- a/tests/foundationals/latent_diffusion/test_lora.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from refiners.adapters.lora import Lora
-from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
-import refiners.fluxion.layers as fl
-
-
-def test_lora_target_self() -> None:
-    chain = fl.Chain(
-        fl.Chain(
-            fl.Linear(in_features=1, out_features=1),
-            fl.Linear(in_features=1, out_features=1),
-        ),
-        fl.Linear(in_features=1, out_features=2),
-    )
-    apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
-
-    assert len(list(chain.layers(Lora))) == 3
diff --git a/tests/foundationals/latent_diffusion/test_self_attention_injection.py b/tests/foundationals/latent_diffusion/test_self_attention_injection.py
new file mode 100644
index 0000000..84e66db
--- /dev/null
+++ b/tests/foundationals/latent_diffusion/test_self_attention_injection.py
@@ -0,0 +1,48 @@
+import torch
+import pytest
+
+
+from refiners.foundationals.latent_diffusion import SD1UNet
+from refiners.foundationals.latent_diffusion.self_attention_injection import (
+    SelfAttentionInjection,
+    SaveLayerNormAdapter,
+    ReferenceOnlyControlAdapter,
+    SelfAttentionInjectionPassthrough,
+)
+from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
+
+
+@torch.no_grad()
+def test_sai_inject_eject() -> None:
+    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    sai = SelfAttentionInjection(unet)
+
+    nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
+    assert nb_cross_attention_blocks > 0
+
+    assert unet.parent is None
+    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
+    assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
+    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
+
+    with pytest.raises(AssertionError) as exc:
+        sai.eject()
+    assert "not the first element" in str(exc.value)
+
+    sai.inject()
+
+    assert unet.parent == sai
+    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
+    assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
+    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == nb_cross_attention_blocks
+
+    with pytest.raises(AssertionError) as exc:
+        sai.inject()
+    assert "already injected" in str(exc.value)
+
+    sai.eject()
+
+    assert unet.parent is None
+    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
+    assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
+    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
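Taken together, the new adapters compose as follows. This is a minimal sketch based on the e2e tests above; the weight paths and condition tensors are placeholders:

```python
import torch

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1, SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter

sd15 = StableDiffusion_1(device="cuda")

# Stack two ControlNets on the same UNet: the second inject() wraps the first
# adapter's replacement, which is what lookup_actual_target later unwinds.
depth = SD1ControlnetAdapter(
    sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors("depth_cn.safetensors")
).inject()
canny = SD1ControlnetAdapter(
    sd15.unet, name="canny", scale=0.7, weights=load_from_safetensors("canny_cn.safetensors")
).inject()

# LoRA patching is an adapter too; lora_targets skips Controlnet linears, so
# this composes with the stack above (the old LoraWeights.patch could not).
SD1LoraAdapter.from_safetensors(
    target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0
).inject()

# Placeholder conditions; real code would use image_to_tensor on guide images.
depth.set_controlnet_condition(torch.randn(1, 3, 512, 512))
canny.set_controlnet_condition(torch.randn(1, 3, 512, 512))

# ... run the usual denoising loop with sd15 ...

# Adapters eject cleanly, in either order (see test_controlnet.py).
canny.eject()
depth.eject()
```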