support injecting several LoRAs simultaneously

Pierre Chapuis 2023-09-04 15:33:40 +02:00
parent 88efa117bf
commit 864937a776
4 changed files with 96 additions and 27 deletions
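In practice, the change lets callers stack several LoRAs on the same model, for example injecting the same checkpoint twice with different scales, as the new end-to-end test at the bottom of this diff does. A minimal sketch of that usage, assuming SD1LoraAdapter is importable from refiners.foundationals.latent_diffusion.lora and that a loaded sd15 model and a lora_weights_path checkpoint already exist:

from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter

# Both injections now coexist on the same target; before this commit the
# second one was rejected because the model already contained LoRA layers.
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject()
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject()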

View file

@@ -72,10 +72,13 @@ class Adapter(Generic[T]):
             self.target._set_parent(target_parent)  # type: ignore[reportPrivateUsage]
             return self
 
-        if self.target not in iter(parent):
+        # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
+        # i.e. initializing two adapters before injecting them.
+        true_parent = parent.find_parent(self.target)
+        if true_parent is None:
             raise ValueError(f"{self.target} is not in {parent}")
 
-        parent.replace(
+        true_parent.replace(
             old_module=self.target,
             new_module=self,
             old_module_parent=target_parent,

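The new comment is the heart of the fix: when two adapters are created against the same parent before either is injected, the first injection re-parents the target, so the second injection can no longer assume the parent recorded at construction time still holds the target directly; find_parent walks the tree to locate it. A toy sketch mirroring the "create twice then inject twice" unit test further down in this diff (illustrative only):

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import LoraAdapter

chain = fl.Chain(fl.Linear(in_features=1, out_features=1))

# both adapters capture the Linear while it still sits directly in `chain`
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)

a1.inject()  # the Linear now lives inside a1's wrapper, not directly in `chain`
a2.inject()  # previously raised ValueError; now resolved via parent.find_parent(self.target)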
View file

@@ -1,6 +1,6 @@
 from enum import Enum
 from pathlib import Path
-from typing import Iterator
+from typing import Iterator, Callable
 
 from torch import Tensor
@@ -8,7 +8,7 @@ import refiners.fluxion.layers as fl
 from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
 from refiners.fluxion.adapters.adapter import Adapter
-from refiners.fluxion.adapters.lora import SingleLoraAdapter, LoraAdapter
+from refiners.fluxion.adapters.lora import LoraAdapter, Lora
 from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
@@ -47,6 +47,23 @@ class LoraTarget(str, Enum):
                 return TransformerLayer
 
 
+def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
+    def f(m: fl.Module, _: fl.Chain) -> bool:
+        if isinstance(m, Lora):  # do not adapt other LoRAs
+            raise StopIteration
+        if isinstance(m, Controlnet):  # do not adapt Controlnet linears
+            raise StopIteration
+        return isinstance(m, k)
+
+    return f
+
+
+def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
+    for m, p in module.walk(_predicate(fl.Linear)):
+        assert isinstance(m, fl.Linear)
+        yield (m, p)
+
+
 def lora_targets(
     module: fl.Chain,
     target: LoraTarget | list[LoraTarget],
@@ -56,29 +73,13 @@ def lora_targets(
             yield from lora_targets(module, t)
         return
 
-    lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
-
-    if isinstance(module, SD1UNet):
-
-        def predicate(m: fl.Module, p: fl.Chain) -> bool:
-            if isinstance(m, Controlnet):  # do not adapt Controlnet linears
-                raise StopIteration
-            return isinstance(m, lookup_class)
-
-    else:
-
-        def predicate(m: fl.Module, p: fl.Chain) -> bool:
-            return isinstance(m, lookup_class)
-
     if target == LoraTarget.Self:
-        for m, p in module.walk(predicate):
-            assert isinstance(m, fl.Linear)
-            yield (m, p)
+        yield from _iter_linears(module)
         return
 
-    for layer, _ in module.walk(predicate):
-        for t in layer.walk(fl.Linear):
-            yield t
+    for layer, _ in module.walk(_predicate(target.get_class())):
+        assert isinstance(layer, fl.Chain)
+        yield from _iter_linears(layer)
 
 
 class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
@@ -101,8 +102,6 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
             if not (model_targets := sub_targets.get(model_name, [])):
                 continue
 
             model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
-            if model.find(SingleLoraAdapter):
-                raise NotImplementedError(f"{model} already contains LoRA layers")
             lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
             self.sub_adapters.append(

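These helpers factor the predicate logic in one place: raising StopIteration from a walk predicate tells fl.Chain.walk not to descend into that subtree, which is how Linears that already belong to a Lora (and Controlnet internals) are left untouched when a second adapter is built. A small usage sketch, assuming this module is importable as refiners.foundationals.latent_diffusion.lora and that a unet chain is already built:

from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets

# every adaptable Linear paired with its parent Chain; Linears sitting under
# an existing Lora or under a Controlnet are skipped by the shared predicate
for linear, parent in lora_targets(unet, LoraTarget.Self):
    print(type(parent).__name__, linear)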
View file

@@ -37,6 +37,32 @@ def test_lora_adapter() -> None:
         fl.Linear(in_features=1, out_features=2),
     )
 
-    LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
+    # create and inject twice
+    a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
+    assert len(list(chain.layers(Lora))) == 3
+    a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
+    assert len(list(chain.layers(Lora))) == 6
+
+    # ejection in forward order
+    a1.eject()
+    assert len(list(chain.layers(Lora))) == 3
+    a2.eject()
+    assert len(list(chain.layers(Lora))) == 0
+
+    # create twice then inject twice
+    a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
+    a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
+    a1.inject()
+    a2.inject()
+    assert len(list(chain.layers(Lora))) == 6
+
+    # ejection in reverse order
+    a2.eject()
+    assert len(list(chain.layers(Lora))) == 3
+    a1.eject()
+    assert len(list(chain.layers(Lora))) == 0

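The end-to-end test in the last file below relies on LoRA deltas being linear in the scale: injecting the same weights at scales 0.4 and 0.6 adds the same residual as a single injection at scale 1.0, which is presumably why it can be checked against the same reference image as the single-LoRA test. A quick plain-torch sanity check of that arithmetic (generic LoRA math, not the refiners API):

import torch

x = torch.randn(2, 8)
down, up = torch.randn(4, 8), torch.randn(8, 4)  # rank-4 LoRA weight matrices

def delta(scale: float) -> torch.Tensor:
    # residual a LoRA of the given scale adds on top of the base layer output
    return scale * (x @ down.T @ up.T)

assert torch.allclose(delta(0.4) + delta(0.6), delta(1.0), atol=1e-5)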
View file

@@ -681,6 +681,47 @@ def test_diffusion_lora(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
+@torch.no_grad()
+def test_diffusion_lora_twice(
+    sd15_std: StableDiffusion_1,
+    lora_data_pokemon: tuple[Image.Image, Path],
+    test_device: torch.device,
+):
+    sd15 = sd15_std
+    n_steps = 30
+
+    expected_image, lora_weights_path = lora_data_pokemon
+
+    if not lora_weights_path.is_file():
+        warn(f"could not find weights at {lora_weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+
+    prompt = "a cute cat"
+
+    with torch.no_grad():
+        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
+
+    sd15.set_num_inference_steps(n_steps)
+
+    SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject()
+    SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject()
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device)
+
+    with torch.no_grad():
+        for step in sd15.steps:
+            x = sd15(
+                x,
+                step=step,
+                clip_text_embedding=clip_text_embedding,
+                condition_scale=7.5,
+            )
+    predicted_image = sd15.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
+
+
 @torch.no_grad()
 def test_diffusion_refonly(
     sd15_ddim: StableDiffusion_1,