From 864937a7761e3ac96ac2af616fcef4b88c3235a0 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Mon, 4 Sep 2023 15:33:40 +0200 Subject: [PATCH] support injecting several LoRAs simultaneously --- src/refiners/fluxion/adapters/adapter.py | 7 ++- .../foundationals/latent_diffusion/lora.py | 47 +++++++++---------- tests/adapters/test_lora.py | 28 ++++++++++- tests/e2e/test_diffusion.py | 41 ++++++++++++++++ 4 files changed, 96 insertions(+), 27 deletions(-) diff --git a/src/refiners/fluxion/adapters/adapter.py b/src/refiners/fluxion/adapters/adapter.py index 5a853aa..ea838f9 100644 --- a/src/refiners/fluxion/adapters/adapter.py +++ b/src/refiners/fluxion/adapters/adapter.py @@ -72,10 +72,13 @@ class Adapter(Generic[T]): self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage] return self - if self.target not in iter(parent): + # In general, `true_parent` is `parent`. We do this to support multiple adaptation, + # i.e. initializing two adapters before injecting them. + true_parent = parent.find_parent(self.target) + if true_parent is None: raise ValueError(f"{self.target} is not in {parent}") - parent.replace( + true_parent.replace( old_module=self.target, new_module=self, old_module_parent=target_parent, diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index 48bb7f3..d0d52c3 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -1,6 +1,6 @@ from enum import Enum from pathlib import Path -from typing import Iterator +from typing import Iterator, Callable from torch import Tensor @@ -8,7 +8,7 @@ import refiners.fluxion.layers as fl from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors from refiners.fluxion.adapters.adapter import Adapter -from refiners.fluxion.adapters.lora import SingleLoraAdapter, LoraAdapter +from refiners.fluxion.adapters.lora import LoraAdapter, Lora from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d @@ -47,6 +47,23 @@ class LoraTarget(str, Enum): return TransformerLayer +def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]: + def f(m: fl.Module, _: fl.Chain) -> bool: + if isinstance(m, Lora): # do not adapt other LoRAs + raise StopIteration + if isinstance(m, Controlnet): # do not adapt Controlnet linears + raise StopIteration + return isinstance(m, k) + + return f + + +def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]: + for m, p in module.walk(_predicate(fl.Linear)): + assert isinstance(m, fl.Linear) + yield (m, p) + + def lora_targets( module: fl.Chain, target: LoraTarget | list[LoraTarget], @@ -56,29 +73,13 @@ def lora_targets( yield from lora_targets(module, t) return - lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class() - - if isinstance(module, SD1UNet): - - def predicate(m: fl.Module, p: fl.Chain) -> bool: - if isinstance(m, Controlnet): # do not adapt Controlnet linears - raise StopIteration - return isinstance(m, lookup_class) - - else: - - def predicate(m: fl.Module, p: fl.Chain) -> bool: - return isinstance(m, lookup_class) - if target == LoraTarget.Self: - for m, p in module.walk(predicate): - assert isinstance(m, fl.Linear) - yield (m, p) + yield from _iter_linears(module) return - for layer, _ in module.walk(predicate): - for t in layer.walk(fl.Linear): - yield t + for layer, _ in module.walk(_predicate(target.get_class())): + assert isinstance(layer, fl.Chain) + yield from _iter_linears(layer) class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]): @@ -101,8 +102,6 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]): if not (model_targets := sub_targets.get(model_name, [])): continue model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name) - if model.find(SingleLoraAdapter): - raise NotImplementedError(f"{model} already contains LoRA layers") lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None self.sub_adapters.append( diff --git a/tests/adapters/test_lora.py b/tests/adapters/test_lora.py index b73304d..1784bd9 100644 --- a/tests/adapters/test_lora.py +++ b/tests/adapters/test_lora.py @@ -37,6 +37,32 @@ def test_lora_adapter() -> None: fl.Linear(in_features=1, out_features=2), ) - LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject() + # create and inject twice + a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject() assert len(list(chain.layers(Lora))) == 3 + + a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject() + assert len(list(chain.layers(Lora))) == 6 + + # ejection in forward order + + a1.eject() + assert len(list(chain.layers(Lora))) == 3 + a2.eject() + assert len(list(chain.layers(Lora))) == 0 + + # create twice then inject twice + + a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0) + a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0) + a1.inject() + a2.inject() + assert len(list(chain.layers(Lora))) == 6 + + # ejection in reverse order + + a2.eject() + assert len(list(chain.layers(Lora))) == 3 + a1.eject() + assert len(list(chain.layers(Lora))) == 0 diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index c728af2..f5bcc1a 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -681,6 +681,47 @@ def test_diffusion_lora( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) +@torch.no_grad() +def test_diffusion_lora_twice( + sd15_std: StableDiffusion_1, + lora_data_pokemon: tuple[Image.Image, Path], + test_device: torch.device, +): + sd15 = sd15_std + n_steps = 30 + + expected_image, lora_weights_path = lora_data_pokemon + + if not lora_weights_path.is_file(): + warn(f"could not find weights at {lora_weights_path}, skipping") + pytest.skip(allow_module_level=True) + + prompt = "a cute cat" + + with torch.no_grad(): + clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + + sd15.set_num_inference_steps(n_steps) + + SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject() + SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject() + + manual_seed(2) + x = torch.randn(1, 4, 64, 64, device=test_device) + + with torch.no_grad(): + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) + + ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) + + @torch.no_grad() def test_diffusion_refonly( sd15_ddim: StableDiffusion_1,