mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
support injecting several LoRAs simultaneously
This commit is contained in:
parent
88efa117bf
commit
864937a776
|
@ -72,10 +72,13 @@ class Adapter(Generic[T]):
|
|||
self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage]
|
||||
return self
|
||||
|
||||
if self.target not in iter(parent):
|
||||
# In general, `true_parent` is `parent`. We do this to support multiple adaptation,
|
||||
# i.e. initializing two adapters before injecting them.
|
||||
true_parent = parent.find_parent(self.target)
|
||||
if true_parent is None:
|
||||
raise ValueError(f"{self.target} is not in {parent}")
|
||||
|
||||
parent.replace(
|
||||
true_parent.replace(
|
||||
old_module=self.target,
|
||||
new_module=self,
|
||||
old_module_parent=target_parent,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Callable
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
@ -8,7 +8,7 @@ import refiners.fluxion.layers as fl
|
|||
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
||||
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.adapters.lora import SingleLoraAdapter, LoraAdapter
|
||||
from refiners.fluxion.adapters.lora import LoraAdapter, Lora
|
||||
|
||||
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
|
@ -47,6 +47,23 @@ class LoraTarget(str, Enum):
|
|||
return TransformerLayer
|
||||
|
||||
|
||||
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
||||
def f(m: fl.Module, _: fl.Chain) -> bool:
|
||||
if isinstance(m, Lora): # do not adapt other LoRAs
|
||||
raise StopIteration
|
||||
if isinstance(m, Controlnet): # do not adapt Controlnet linears
|
||||
raise StopIteration
|
||||
return isinstance(m, k)
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
for m, p in module.walk(_predicate(fl.Linear)):
|
||||
assert isinstance(m, fl.Linear)
|
||||
yield (m, p)
|
||||
|
||||
|
||||
def lora_targets(
|
||||
module: fl.Chain,
|
||||
target: LoraTarget | list[LoraTarget],
|
||||
|
@ -56,29 +73,13 @@ def lora_targets(
|
|||
yield from lora_targets(module, t)
|
||||
return
|
||||
|
||||
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
|
||||
|
||||
if isinstance(module, SD1UNet):
|
||||
|
||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
||||
if isinstance(m, Controlnet): # do not adapt Controlnet linears
|
||||
raise StopIteration
|
||||
return isinstance(m, lookup_class)
|
||||
|
||||
else:
|
||||
|
||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
||||
return isinstance(m, lookup_class)
|
||||
|
||||
if target == LoraTarget.Self:
|
||||
for m, p in module.walk(predicate):
|
||||
assert isinstance(m, fl.Linear)
|
||||
yield (m, p)
|
||||
yield from _iter_linears(module)
|
||||
return
|
||||
|
||||
for layer, _ in module.walk(predicate):
|
||||
for t in layer.walk(fl.Linear):
|
||||
yield t
|
||||
for layer, _ in module.walk(_predicate(target.get_class())):
|
||||
assert isinstance(layer, fl.Chain)
|
||||
yield from _iter_linears(layer)
|
||||
|
||||
|
||||
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
||||
|
@ -101,8 +102,6 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
|||
if not (model_targets := sub_targets.get(model_name, [])):
|
||||
continue
|
||||
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
|
||||
if model.find(SingleLoraAdapter):
|
||||
raise NotImplementedError(f"{model} already contains LoRA layers")
|
||||
|
||||
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
|
||||
self.sub_adapters.append(
|
||||
|
|
|
@ -37,6 +37,32 @@ def test_lora_adapter() -> None:
|
|||
fl.Linear(in_features=1, out_features=2),
|
||||
)
|
||||
|
||||
LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
||||
# create and inject twice
|
||||
|
||||
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
||||
assert len(list(chain.layers(Lora))) == 3
|
||||
|
||||
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
||||
assert len(list(chain.layers(Lora))) == 6
|
||||
|
||||
# ejection in forward order
|
||||
|
||||
a1.eject()
|
||||
assert len(list(chain.layers(Lora))) == 3
|
||||
a2.eject()
|
||||
assert len(list(chain.layers(Lora))) == 0
|
||||
|
||||
# create twice then inject twice
|
||||
|
||||
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
|
||||
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
|
||||
a1.inject()
|
||||
a2.inject()
|
||||
assert len(list(chain.layers(Lora))) == 6
|
||||
|
||||
# ejection in reverse order
|
||||
|
||||
a2.eject()
|
||||
assert len(list(chain.layers(Lora))) == 3
|
||||
a1.eject()
|
||||
assert len(list(chain.layers(Lora))) == 0
|
||||
|
|
|
@ -681,6 +681,47 @@ def test_diffusion_lora(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_lora_twice(
|
||||
sd15_std: StableDiffusion_1,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
expected_image, lora_weights_path = lora_data_pokemon
|
||||
|
||||
if not lora_weights_path.is_file():
|
||||
warn(f"could not find weights at {lora_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
prompt = "a cute cat"
|
||||
|
||||
with torch.no_grad():
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject()
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject()
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
||||
with torch.no_grad():
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_refonly(
|
||||
sd15_ddim: StableDiffusion_1,
|
||||
|
|
Loading…
Reference in a new issue