support injecting several LoRAs simultaneously

This commit is contained in:
Pierre Chapuis 2023-09-04 15:33:40 +02:00
parent 88efa117bf
commit 864937a776
4 changed files with 96 additions and 27 deletions

View file

@ -72,10 +72,13 @@ class Adapter(Generic[T]):
self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage] self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage]
return self return self
if self.target not in iter(parent): # In general, `true_parent` is `parent`. We do this to support multiple adaptation,
# i.e. initializing two adapters before injecting them.
true_parent = parent.find_parent(self.target)
if true_parent is None:
raise ValueError(f"{self.target} is not in {parent}") raise ValueError(f"{self.target} is not in {parent}")
parent.replace( true_parent.replace(
old_module=self.target, old_module=self.target,
new_module=self, new_module=self,
old_module_parent=target_parent, old_module_parent=target_parent,

View file

@ -1,6 +1,6 @@
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Iterator from typing import Iterator, Callable
from torch import Tensor from torch import Tensor
@ -8,7 +8,7 @@ import refiners.fluxion.layers as fl
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from refiners.fluxion.adapters.adapter import Adapter from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import SingleLoraAdapter, LoraAdapter from refiners.fluxion.adapters.lora import LoraAdapter, Lora
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
@ -47,6 +47,23 @@ class LoraTarget(str, Enum):
return TransformerLayer return TransformerLayer
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt other LoRAs
raise StopIteration
if isinstance(m, Controlnet): # do not adapt Controlnet linears
raise StopIteration
return isinstance(m, k)
return f
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
for m, p in module.walk(_predicate(fl.Linear)):
assert isinstance(m, fl.Linear)
yield (m, p)
def lora_targets( def lora_targets(
module: fl.Chain, module: fl.Chain,
target: LoraTarget | list[LoraTarget], target: LoraTarget | list[LoraTarget],
@ -56,29 +73,13 @@ def lora_targets(
yield from lora_targets(module, t) yield from lora_targets(module, t)
return return
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
if isinstance(module, SD1UNet):
def predicate(m: fl.Module, p: fl.Chain) -> bool:
if isinstance(m, Controlnet): # do not adapt Controlnet linears
raise StopIteration
return isinstance(m, lookup_class)
else:
def predicate(m: fl.Module, p: fl.Chain) -> bool:
return isinstance(m, lookup_class)
if target == LoraTarget.Self: if target == LoraTarget.Self:
for m, p in module.walk(predicate): yield from _iter_linears(module)
assert isinstance(m, fl.Linear)
yield (m, p)
return return
for layer, _ in module.walk(predicate): for layer, _ in module.walk(_predicate(target.get_class())):
for t in layer.walk(fl.Linear): assert isinstance(layer, fl.Chain)
yield t yield from _iter_linears(layer)
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]): class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
@ -101,8 +102,6 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
if not (model_targets := sub_targets.get(model_name, [])): if not (model_targets := sub_targets.get(model_name, [])):
continue continue
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name) model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
if model.find(SingleLoraAdapter):
raise NotImplementedError(f"{model} already contains LoRA layers")
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
self.sub_adapters.append( self.sub_adapters.append(

View file

@ -37,6 +37,32 @@ def test_lora_adapter() -> None:
fl.Linear(in_features=1, out_features=2), fl.Linear(in_features=1, out_features=2),
) )
LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject() # create and inject twice
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
assert len(list(chain.layers(Lora))) == 3 assert len(list(chain.layers(Lora))) == 3
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
assert len(list(chain.layers(Lora))) == 6
# ejection in forward order
a1.eject()
assert len(list(chain.layers(Lora))) == 3
a2.eject()
assert len(list(chain.layers(Lora))) == 0
# create twice then inject twice
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
a1.inject()
a2.inject()
assert len(list(chain.layers(Lora))) == 6
# ejection in reverse order
a2.eject()
assert len(list(chain.layers(Lora))) == 3
a1.eject()
assert len(list(chain.layers(Lora))) == 0

View file

@ -681,6 +681,47 @@ def test_diffusion_lora(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_lora_twice(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
expected_image, lora_weights_path = lora_data_pokemon
if not lora_weights_path.is_file():
warn(f"could not find weights at {lora_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat"
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject()
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject()
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad() @torch.no_grad()
def test_diffusion_refonly( def test_diffusion_refonly(
sd15_ddim: StableDiffusion_1, sd15_ddim: StableDiffusion_1,