mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
support injecting several LoRAs simultaneously
This commit is contained in:
parent
88efa117bf
commit
864937a776
|
@ -72,10 +72,13 @@ class Adapter(Generic[T]):
|
||||||
self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage]
|
self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage]
|
||||||
return self
|
return self
|
||||||
|
|
||||||
if self.target not in iter(parent):
|
# In general, `true_parent` is `parent`. We do this to support multiple adaptation,
|
||||||
|
# i.e. initializing two adapters before injecting them.
|
||||||
|
true_parent = parent.find_parent(self.target)
|
||||||
|
if true_parent is None:
|
||||||
raise ValueError(f"{self.target} is not in {parent}")
|
raise ValueError(f"{self.target} is not in {parent}")
|
||||||
|
|
||||||
parent.replace(
|
true_parent.replace(
|
||||||
old_module=self.target,
|
old_module=self.target,
|
||||||
new_module=self,
|
new_module=self,
|
||||||
old_module_parent=target_parent,
|
old_module_parent=target_parent,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator
|
from typing import Iterator, Callable
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
||||||
|
|
||||||
from refiners.fluxion.adapters.adapter import Adapter
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
from refiners.fluxion.adapters.lora import SingleLoraAdapter, LoraAdapter
|
from refiners.fluxion.adapters.lora import LoraAdapter, Lora
|
||||||
|
|
||||||
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
|
@ -47,6 +47,23 @@ class LoraTarget(str, Enum):
|
||||||
return TransformerLayer
|
return TransformerLayer
|
||||||
|
|
||||||
|
|
||||||
|
def _predicate(k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
|
||||||
|
def f(m: fl.Module, _: fl.Chain) -> bool:
|
||||||
|
if isinstance(m, Lora): # do not adapt other LoRAs
|
||||||
|
raise StopIteration
|
||||||
|
if isinstance(m, Controlnet): # do not adapt Controlnet linears
|
||||||
|
raise StopIteration
|
||||||
|
return isinstance(m, k)
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def _iter_linears(module: fl.Chain) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||||
|
for m, p in module.walk(_predicate(fl.Linear)):
|
||||||
|
assert isinstance(m, fl.Linear)
|
||||||
|
yield (m, p)
|
||||||
|
|
||||||
|
|
||||||
def lora_targets(
|
def lora_targets(
|
||||||
module: fl.Chain,
|
module: fl.Chain,
|
||||||
target: LoraTarget | list[LoraTarget],
|
target: LoraTarget | list[LoraTarget],
|
||||||
|
@ -56,29 +73,13 @@ def lora_targets(
|
||||||
yield from lora_targets(module, t)
|
yield from lora_targets(module, t)
|
||||||
return
|
return
|
||||||
|
|
||||||
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
|
|
||||||
|
|
||||||
if isinstance(module, SD1UNet):
|
|
||||||
|
|
||||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
|
||||||
if isinstance(m, Controlnet): # do not adapt Controlnet linears
|
|
||||||
raise StopIteration
|
|
||||||
return isinstance(m, lookup_class)
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
|
||||||
return isinstance(m, lookup_class)
|
|
||||||
|
|
||||||
if target == LoraTarget.Self:
|
if target == LoraTarget.Self:
|
||||||
for m, p in module.walk(predicate):
|
yield from _iter_linears(module)
|
||||||
assert isinstance(m, fl.Linear)
|
|
||||||
yield (m, p)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
for layer, _ in module.walk(predicate):
|
for layer, _ in module.walk(_predicate(target.get_class())):
|
||||||
for t in layer.walk(fl.Linear):
|
assert isinstance(layer, fl.Chain)
|
||||||
yield t
|
yield from _iter_linears(layer)
|
||||||
|
|
||||||
|
|
||||||
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
||||||
|
@ -101,8 +102,6 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
||||||
if not (model_targets := sub_targets.get(model_name, [])):
|
if not (model_targets := sub_targets.get(model_name, [])):
|
||||||
continue
|
continue
|
||||||
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
|
model = getattr(target, "clip_text_encoder" if model_name == "text_encoder" else model_name)
|
||||||
if model.find(SingleLoraAdapter):
|
|
||||||
raise NotImplementedError(f"{model} already contains LoRA layers")
|
|
||||||
|
|
||||||
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
|
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
|
||||||
self.sub_adapters.append(
|
self.sub_adapters.append(
|
||||||
|
|
|
@ -37,6 +37,32 @@ def test_lora_adapter() -> None:
|
||||||
fl.Linear(in_features=1, out_features=2),
|
fl.Linear(in_features=1, out_features=2),
|
||||||
)
|
)
|
||||||
|
|
||||||
LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
# create and inject twice
|
||||||
|
|
||||||
|
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
||||||
assert len(list(chain.layers(Lora))) == 3
|
assert len(list(chain.layers(Lora))) == 3
|
||||||
|
|
||||||
|
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
||||||
|
assert len(list(chain.layers(Lora))) == 6
|
||||||
|
|
||||||
|
# ejection in forward order
|
||||||
|
|
||||||
|
a1.eject()
|
||||||
|
assert len(list(chain.layers(Lora))) == 3
|
||||||
|
a2.eject()
|
||||||
|
assert len(list(chain.layers(Lora))) == 0
|
||||||
|
|
||||||
|
# create twice then inject twice
|
||||||
|
|
||||||
|
a1 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
|
||||||
|
a2 = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0)
|
||||||
|
a1.inject()
|
||||||
|
a2.inject()
|
||||||
|
assert len(list(chain.layers(Lora))) == 6
|
||||||
|
|
||||||
|
# ejection in reverse order
|
||||||
|
|
||||||
|
a2.eject()
|
||||||
|
assert len(list(chain.layers(Lora))) == 3
|
||||||
|
a1.eject()
|
||||||
|
assert len(list(chain.layers(Lora))) == 0
|
||||||
|
|
|
@ -681,6 +681,47 @@ def test_diffusion_lora(
|
||||||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_diffusion_lora_twice(
|
||||||
|
sd15_std: StableDiffusion_1,
|
||||||
|
lora_data_pokemon: tuple[Image.Image, Path],
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_std
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
expected_image, lora_weights_path = lora_data_pokemon
|
||||||
|
|
||||||
|
if not lora_weights_path.is_file():
|
||||||
|
warn(f"could not find weights at {lora_weights_path}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
prompt = "a cute cat"
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.4).inject()
|
||||||
|
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=0.6).inject()
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=7.5,
|
||||||
|
)
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def test_diffusion_refonly(
|
def test_diffusion_refonly(
|
||||||
sd15_ddim: StableDiffusion_1,
|
sd15_ddim: StableDiffusion_1,
|
||||||
|
|
Loading…
Reference in a new issue