make high-level adapters Adapters

This generalizes the Adapter abstraction to higher-level
constructs such as high-level LoRA (targeting e.g. the
SD UNet), ControlNet and Reference-Only Control.

Some adapters now work by adapting child models with
"sub-adapters" that they inject / eject when needed.
Pierre Chapuis 2023-08-31 10:40:01 +02:00
parent 7dc2e93cff
commit 0f476ea18b
24 changed files with 603 additions and 282 deletions
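In practice the new pattern looks like this (a minimal sketch assuming the APIs introduced in this commit; the `name="canny"` value is illustrative):

```python
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter

unet = SD1UNet(in_channels=4, clip_embedding_dim=768)

# High-level adapters now wrap their target and take care of injecting any
# sub-modules; inject() returns the adapter, so construction chains with it.
adapter = SD1ControlnetAdapter(unet, name="canny").inject()
# ... run inference through the adapted UNet ...
adapter.eject()  # restores the original, unadapted UNet
```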

View file

@@ -179,12 +179,11 @@ The `Adapter` API lets you **easily patch models** by injecting parameters in ta
 E.g. to inject LoRA layers in all attention's linear layers:
 ```python
-from refiners.adapters.lora import LoraAdapter
+from refiners.adapters.lora import SingleLoraAdapter
 for layer in vit.layers(fl.Attention):
     for linear, parent in layer.walk(fl.Linear):
-        adapter = LoraAdapter(target=linear, rank=64)
-        adapter.inject(parent)
+        SingleLoraAdapter(target=linear, rank=64).inject(parent)
 # ... and load existing weights if the LoRAs are pretrained ...
 ```
@@ -232,7 +231,7 @@ Step 3: run inference using the GPU:
 ```python
 from refiners.foundationals.latent_diffusion import StableDiffusion_1
-from refiners.foundationals.latent_diffusion.lora import LoraWeights
+from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.fluxion.utils import load_from_safetensors, manual_seed
 import torch
@@ -242,9 +241,7 @@ sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors")
 sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
 sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
-# This uses the LoraAdapter internally and takes care to inject it where it should
-lora_weights = LoraWeights("pokemon_lora.safetensors", device=sd15.device)
-lora_weights.patch(sd15, scale=1.0)
+SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()
 prompt = "a cute cat"

View file

@@ -8,7 +8,7 @@ from refiners.fluxion.utils import save_to_safetensors
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.foundationals.latent_diffusion import (
     SD1UNet,
-    SD1Controlnet,
+    SD1ControlnetAdapter,
     DPMSolver,
 )
@@ -21,13 +21,13 @@ class Args(argparse.Namespace):
 @torch.no_grad()
 def convert(args: Args) -> dict[str, torch.Tensor]:
     controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path)  # type: ignore
-    controlnet = SD1Controlnet(name="mycn")
+    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
+    adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
+    controlnet = unet.Controlnet
     condition = torch.randn(1, 3, 512, 512)
-    controlnet.set_controlnet_condition(condition=condition)
-    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
-    unet.insert(index=0, module=controlnet)
+    adapter.set_controlnet_condition(condition=condition)
     clip_text_embedding = torch.rand(1, 77, 768)
     unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

View file

@@ -1,17 +1,20 @@
 import argparse
 from pathlib import Path
 from typing import cast
 import torch
 from torch import Tensor
 from torch.nn.init import zeros_
 from torch.nn import Parameter as TorchParameter
 from diffusers import DiffusionPipeline  # type: ignore
 import refiners.fluxion.layers as fl
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.fluxion.utils import save_to_safetensors
+from refiners.adapters.lora import Lora, LoraAdapter
 from refiners.foundationals.latent_diffusion import SD1UNet
-from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
-from refiners.adapters.lora import Lora
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
 def get_weight(linear: fl.Linear) -> torch.Tensor:
@@ -69,7 +72,8 @@ def process(args: Args) -> None:
     diffusers_to_refiners = converter.get_mapping()
-    apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
+    LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
     for layer in refiners_model.layers(layer_type=Lora):
         zeros_(tensor=layer.Linear_1.weight)
@@ -85,7 +89,9 @@ def process(args: Args) -> None:
         p = p[seg]
     assert isinstance(p, fl.Chain)
     last_seg = (
-        "LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
+        "SingleLoraAdapter"
+        if orig_path[-1] == "Linear"
+        else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
     )
     p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
     p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])

View file

@@ -2,9 +2,8 @@ import random
 from typing import Any
 from pydantic import BaseModel
 from loguru import logger
-from refiners.adapters.lora import LoraAdapter, Lora
 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
+from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
 import refiners.fluxion.layers as fl
 from torch import Tensor
 from torch.utils.data import Dataset
@@ -27,13 +26,6 @@ class LoraConfig(BaseModel):
     text_encoder_targets: list[LoraTarget]
     lda_targets: list[LoraTarget]
-    def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
-        for linear, parent in lora_targets(module, target):
-            adapter = LoraAdapter(target=linear, rank=self.rank)
-            adapter.inject(parent)
-            for linear in adapter.Lora.layers(fl.Linear):
-                linear.requires_grad_(requires_grad=True)
 class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
     def __init__(
@@ -84,45 +76,30 @@ class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfi
 class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
     def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
         lora_config = trainer.config.lora
-        for target in lora_config.unet_targets:
-            lora_config.apply_loras_to_target(module=trainer.unet, target=target)
-        for target in lora_config.text_encoder_targets:
-            lora_config.apply_loras_to_target(module=trainer.text_encoder, target=target)
-        for target in lora_config.lda_targets:
-            lora_config.apply_loras_to_target(module=trainer.lda, target=target)
+        for model_name in MODELS:
+            model = getattr(trainer, model_name)
+            adapter = LoraAdapter[type(model)](
+                model,
+                sub_targets=getattr(lora_config, f"{model_name}_targets"),
+                rank=lora_config.rank,
+            )
+            for sub_adapter, _ in adapter.sub_adapters:
+                for linear in sub_adapter.Lora.layers(fl.Linear):
+                    linear.requires_grad_(requires_grad=True)
+            adapter.inject()
 class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
     def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
-        lora_config = trainer.config.lora
-        def get_weight(linear: fl.Linear) -> Tensor:
-            assert linear.bias is None
-            return linear.state_dict()["weight"]
-        def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, Tensor]:
-            weights: list[Tensor] = []
-            for lora in module.layers(layer_type=Lora):
-                linears = list(lora.layers(fl.Linear))
-                assert len(linears) == 2
-                # See `load_lora_weights` in refiners.adapters.lora
-                weights.extend((get_weight(linears[1]), get_weight(linears[0])))  # aka (up_weight, down_weight)
-            return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
         tensors: dict[str, Tensor] = {}
         metadata: dict[str, str] = {}
-        if lora_config.unet_targets:
-            tensors |= build_loras_safetensors(trainer.unet, key_prefix="unet.")
-            metadata |= {"unet_targets": ",".join(lora_config.unet_targets)}
-        if lora_config.text_encoder_targets:
-            tensors |= build_loras_safetensors(trainer.text_encoder, key_prefix="text_encoder.")
-            metadata |= {"text_encoder_targets": ",".join(lora_config.text_encoder_targets)}
-        if lora_config.lda_targets:
-            tensors |= build_loras_safetensors(trainer.lda, key_prefix="lda.")
-            metadata |= {"lda_targets": ",".join(lora_config.lda_targets)}
+        for model_name in MODELS:
+            model = getattr(trainer, model_name)
+            adapter = model.parent
+            tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
+            metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}
         save_to_safetensors(
             path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",

View file

@@ -4,7 +4,7 @@ from typing import Any, Generic, TypeVar, Iterator
 T = TypeVar("T", bound=fl.Module)
-TAdapter = TypeVar("TAdapter", bound="Adapter[fl.Module]")
+TAdapter = TypeVar("TAdapter", bound="Adapter[Any]")  # Self (see PEP 673)
 class Adapter(Generic[T]):
@@ -36,27 +36,61 @@ class Adapter(Generic[T]):
             yield
             target._can_refresh_parent = _old_can_refresh_parent
-    def inject(self, parent: fl.Chain | None = None) -> None:
+    def lookup_actual_target(self) -> fl.Module:
+        # In general, the "actual target" is the target.
+        # This method deals with the edge case where the target
+        # is part of the replacement block and has been adapted by
+        # another adapter after this one. For instance, this is the
+        # case when stacking Controlnets.
         assert isinstance(self, fl.Chain)
+        target_parent = self.find_parent(self.target)
+        if (target_parent is None) or (target_parent == self):
+            return self.target
+        # Lookup and return last adapter in parents tree (or target if none).
+        r, p = self.target, target_parent
+        while p != self:
+            if isinstance(p, Adapter):
+                r = p
+            assert p.parent, f"parent tree of {self} is broken"
+            p = p.parent
+        return r
+    def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
+        assert isinstance(self, fl.Chain)
+        if (parent is None) and isinstance(self.target, fl.ContextModule):
+            parent = self.target.parent
+        if parent is not None:
+            assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
+        target_parent = self.find_parent(self.target)
         if parent is None:
             if isinstance(self.target, fl.ContextModule):
-                parent = self.target.parent
-            else:
-                raise ValueError(f"parent of {self.target} is mandatory")
-        assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
+                self.target._set_parent(target_parent)  # type: ignore[reportPrivateUsage]
+            return self
         if self.target not in iter(parent):
             raise ValueError(f"{self.target} is not in {parent}")
         parent.replace(
             old_module=self.target,
             new_module=self,
-            old_module_parent=self.find_parent(self.target),
+            old_module_parent=target_parent,
         )
+        return self
     def eject(self) -> None:
         assert isinstance(self, fl.Chain)
-        self.ensure_parent.replace(old_module=self, new_module=self.target)
+        actual_target = self.lookup_actual_target()
+        if (parent := self.parent) is None:
+            if isinstance(actual_target, fl.ContextModule):
+                actual_target._set_parent(None)  # type: ignore[reportPrivateUsage]
+        else:
+            parent.replace(old_module=self, new_module=actual_target)
     def _pre_structural_copy(self) -> None:
         if isinstance(self.target, fl.Chain):
View file

@@ -1,8 +1,14 @@
+from typing import Iterable, Generic, TypeVar, Any
 import refiners.fluxion.layers as fl
 from refiners.adapters.adapter import Adapter
-from torch.nn.init import zeros_, normal_
 from torch import Tensor, device as Device, dtype as DType
+from torch.nn import Parameter as TorchParameter
+from torch.nn.init import zeros_, normal_
+T = TypeVar("T", bound=fl.Chain)
+TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]")  # Self (see PEP 673)
 class Lora(fl.Chain):
@@ -37,11 +43,19 @@ class Lora(fl.Chain):
         self.scale = scale
     def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
-        self.Linear_1.weight = down_weight
-        self.Linear_2.weight = up_weight
+        self.Linear_1.weight = TorchParameter(down_weight)
+        self.Linear_2.weight = TorchParameter(up_weight)
+    @property
+    def up_weight(self) -> Tensor:
+        return self.Linear_2.weight.data
+    @property
+    def down_weight(self) -> Tensor:
+        return self.Linear_1.weight.data
-class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
+class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
     structural_attrs = ["in_features", "out_features", "rank", "scale"]
     def __init__(
@@ -67,20 +81,54 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
         )
         self.Lora.set_scale(scale=scale)
-    def add_lora(self, lora: Lora) -> None:
-        self.append(module=lora)
-    def load_lora_weights(self, up_weight: Tensor, down_weight: Tensor, index: int = 0) -> None:
-        self[index + 1].load_weights(up_weight=up_weight, down_weight=down_weight)
-def load_lora_weights(model: fl.Chain, weights: list[Tensor]) -> None:
-    assert len(weights) % 2 == 0, "Number of weights must be even"
-    assert (
-        len(list(model.layers(layer_type=Lora))) == len(weights) // 2
-    ), "Number of Lora layers must match number of weights"
-    for i, lora in enumerate(iterable=model.layers(layer_type=Lora)):
-        assert (
-            lora.rank == weights[i * 2].shape[1]
-        ), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
-        lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
+class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
+    def __init__(
+        self,
+        target: T,
+        sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
+        rank: int | None = None,
+        scale: float = 1.0,
+        weights: list[Tensor] | None = None,
+    ) -> None:
+        with self.setup_adapter(target):
+            super().__init__(target)
+        if weights is not None:
+            assert len(weights) % 2 == 0
+            weights_rank = weights[0].shape[1]
+            if rank is None:
+                rank = weights_rank
+            else:
+                assert rank == weights_rank
+        assert rank is not None, "either pass a rank or weights"
+        self.sub_targets = sub_targets
+        self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
+        for linear, parent in self.sub_targets:
+            self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
+        if weights is not None:
+            assert len(self.sub_adapters) == (len(weights) // 2)
+            for i, (adapter, _) in enumerate(self.sub_adapters):
+                lora = adapter.Lora
+                assert (
+                    lora.rank == weights[i * 2].shape[1]
+                ), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
+                adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
+    def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
+        for adapter, adapter_parent in self.sub_adapters:
+            adapter.inject(adapter_parent)
+        return super().inject(parent)
+    def eject(self) -> None:
+        for adapter, _ in self.sub_adapters:
+            adapter.eject()
+        super().eject()
+    @property
+    def weights(self) -> list[Tensor]:
+        return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
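For reference, a minimal sketch (mirroring the new unit test further down) of adapting every `fl.Linear` in a toy chain with the new high-level `LoraAdapter`:

```python
import refiners.fluxion.layers as fl
from refiners.adapters.lora import Lora, LoraAdapter

chain = fl.Chain(
    fl.Linear(in_features=1, out_features=1),
    fl.Linear(in_features=1, out_features=2),
)

# sub_targets yields (linear, parent) pairs; one SingleLoraAdapter is created
# per pair and injected when the high-level adapter itself is injected.
adapter = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1).inject()
assert len(list(chain.layers(Lora))) == 2
adapter.eject()  # ejects all sub-adapters, restoring the plain chain
```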

View file

@@ -10,7 +10,7 @@ from torch.nn import Parameter
 import re
-class ConceptExtender:
+class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
     """
     Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
@@ -37,6 +37,9 @@
     """
     def __init__(self, target: CLIPTextEncoder) -> None:
+        with self.setup_adapter(target):
+            super().__init__(target)
         try:
             token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
         except StopIteration:
@@ -54,13 +57,15 @@
         self.embedding_extender.add_embedding(embedding)
         self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
-    def inject(self) -> None:
+    def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
         self.embedding_extender.inject(self.token_encoder_parent)
         self.token_extender.inject(self.clip_tokenizer_parent)
+        return super().inject(parent)
     def eject(self) -> None:
         self.embedding_extender.eject()
         self.token_extender.eject()
+        super().eject()
 class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):

View file

@@ -9,7 +9,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     StableDiffusion_1,
     StableDiffusion_1_Inpainting,
     SD1UNet,
-    SD1Controlnet,
+    SD1ControlnetAdapter,
 )
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
     SDXLUNet,
@@ -21,7 +21,7 @@ __all__ = [
     "StableDiffusion_1",
     "StableDiffusion_1_Inpainting",
     "SD1UNet",
-    "SD1Controlnet",
+    "SD1ControlnetAdapter",
     "SDXLUNet",
     "DoubleTextEncoder",
     "DPMSolver",

View file

@@ -2,17 +2,26 @@ from enum import Enum
 from pathlib import Path
 from typing import Iterator
-from torch import Tensor, device as Device
-from torch.nn import Parameter as TorchParameter
-from refiners.adapters.lora import LoraAdapter, load_lora_weights
-from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
-from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
-from refiners.foundationals.latent_diffusion import StableDiffusion_1
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
+from torch import Tensor
 import refiners.fluxion.layers as fl
 from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
+from refiners.adapters.adapter import Adapter
+from refiners.adapters.lora import SingleLoraAdapter, LoraAdapter
+from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
+from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
+from refiners.foundationals.latent_diffusion import (
+    StableDiffusion_1,
+    SD1UNet,
+    CLIPTextEncoderL,
+    LatentDiffusionAutoencoder,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
+MODELS = ["unet", "text_encoder", "lda"]
 class LoraTarget(str, Enum):
     Self = "self"
@@ -38,66 +47,95 @@ class LoraTarget(str, Enum):
         return TransformerLayer
+def get_lora_rank(weights: list[Tensor]) -> int:
+    ranks: set[int] = {w.shape[1] for w in weights[0::2]}
+    assert len(ranks) == 1
+    return ranks.pop()
 def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
-    it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
-    for layer in it:
+    lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
+    if isinstance(module, SD1UNet):
+        def predicate(m: fl.Module, p: fl.Chain) -> bool:
+            if isinstance(m, Controlnet):  # do not adapt Controlnet linears
+                raise StopIteration
+            return isinstance(m, lookup_class)
+    else:
+        def predicate(m: fl.Module, p: fl.Chain) -> bool:
+            return isinstance(m, lookup_class)
+    if target == LoraTarget.Self:
+        for m, p in module.walk(predicate):
+            assert isinstance(m, fl.Linear)
+            yield (m, p)
+        return
+    for layer, _ in module.walk(predicate):
         for t in layer.walk(fl.Linear):
             yield t
-def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
-    for linear, parent in lora_targets(module, target):
-        adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
-        adapter.inject(parent)
-class LoraWeights:
-    """A single LoRA weights training checkpoint used to patch a Stable Diffusion 1.5 model."""
+class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
     metadata: dict[str, str] | None
     tensors: dict[str, Tensor]
-    def __init__(self, checkpoint_path: Path | str, device: Device | str):
-        self.metadata = load_metadata_from_safetensors(checkpoint_path)
-        self.tensors = load_from_safetensors(checkpoint_path, device=device)
-    def patch(self, sd: StableDiffusion_1, scale: float = 1.0) -> None:
-        assert self.metadata is not None, "Invalid safetensors checkpoint: missing metadata"
-        for meta_key, meta_value in self.metadata.items():
-            match meta_key:
-                case "unet_targets":
-                    # TODO: support this transparently
-                    if any([isinstance(module, SD1Controlnet) for module in sd.unet]):
-                        raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
-                    model = sd.unet
-                    key_prefix = "unet."
-                case "text_encoder_targets":
-                    model = sd.clip_text_encoder
-                    key_prefix = "text_encoder."
-                case "lda_targets":
-                    model = sd.lda
-                    key_prefix = "lda."
-                case _:
-                    raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
-            # TODO(FG-487): support loading multiple LoRA-s
-            if any(model.layers(LoraAdapter)):
-                raise NotImplementedError(f"{model.__class__.__name__} already contains LoRA layers")
-            lora_weights = [w for w in [self.tensors[k] for k in sorted(self.tensors) if k.startswith(key_prefix)]]
-            assert len(lora_weights) % 2 == 0
-            rank = get_lora_rank(lora_weights)
-            for target in meta_value.split(","):
-                apply_loras_to_target(model, target=LoraTarget(target), rank=rank, scale=scale)
-            assert len(list(model.layers(LoraAdapter))) == (len(lora_weights) // 2)
-            load_lora_weights(model, [TorchParameter(w) for w in lora_weights])
+    def __init__(
+        self,
+        target: StableDiffusion_1,
+        sub_targets: dict[str, list[LoraTarget]],
+        scale: float = 1.0,
+        weights: dict[str, Tensor] | None = None,
+    ):
+        with self.setup_adapter(target):
+            super().__init__(target)
+        self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
+        for model_name in MODELS:
+            if not (model_targets := sub_targets.get(model_name, [])):
+                continue
+            model = getattr(target, model_name)
+            if model.find(SingleLoraAdapter):
+                raise NotImplementedError(f"{model} already contains LoRA layers")
+            lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
+            self.sub_adapters.append(
+                LoraAdapter[type(model)](
+                    model,
+                    sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
+                    scale=scale,
+                    weights=lora_weights,
+                )
+            )
+    @classmethod
+    def from_safetensors(
+        cls,
+        target: StableDiffusion_1,
+        checkpoint_path: Path | str,
+        scale: float = 1.0,
+    ):
+        metadata = load_metadata_from_safetensors(checkpoint_path)
+        assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
+        tensors = load_from_safetensors(checkpoint_path, device=target.device)
+        sub_targets = {}
+        for model_name in MODELS:
+            if not (v := metadata.get(f"{model_name}_targets", "")):
+                continue
+            sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
+        return cls(
+            target,
+            sub_targets,
+            scale=scale,
+            weights=tensors,
+        )
+    def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
+        for adapter in self.sub_adapters:
+            adapter.inject()
+        return super().inject(parent)
+    def eject(self) -> None:
+        for adapter in self.sub_adapters:
+            adapter.eject()
+        super().eject()
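Since `SD1LoraAdapter` is now a real adapter, a LoRA checkpoint can in principle be swapped out by ejecting it and injecting another one. A sketch (the file names are illustrative, and this assumes `eject()` fully restores the model):

```python
# `sd15` is a loaded StableDiffusion_1 model, as in the README example above.
adapter = SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="style_a.safetensors").inject()
# ... generate images with the first LoRA ...
adapter.eject()
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="style_b.safetensors", scale=0.8).inject()
```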

View file

@@ -27,10 +27,10 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
         self,
         target: SelfAttention,
         context: str,
-        sai: "SelfAttentionInjection",
+        style_cfg: float = 0.5,
     ) -> None:
         self.context = context
-        self._sai = [sai]  # only to support setting `style_cfg` dynamically
+        self.style_cfg = style_cfg
         sa_guided = target.structural_copy()
         assert isinstance(sa_guided[0], Parallel)
@@ -50,62 +50,26 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
         )
     def compute_averaged_unconditioned_x(self, x: Tensor, unguided_unconditioned_x: Tensor) -> Tensor:
-        style_cfg = self._sai[0].style_cfg
-        x[0] = style_cfg * x[0] + (1.0 - style_cfg) * unguided_unconditioned_x
+        x[0] = self.style_cfg * x[0] + (1.0 - self.style_cfg) * unguided_unconditioned_x
         return x
-class SelfAttentionInjection(Passthrough):
-    # TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
-    def __init__(self, unet: SD1UNet, style_cfg: float = 0.5) -> None:
-        # the style_cfg is the weight of the guide in unconditionned diffusion.
-        # This value is recommended to be 0.5 on the sdwebui repo.
-        self.style_cfg = style_cfg
-        self._adapters: list[ReferenceOnlyControlAdapter] = []
-        self._unet = [unet]
-        guide_unet = unet.structural_copy()
+class SelfAttentionInjectionPassthrough(Passthrough):
+    def __init__(self, target: SD1UNet) -> None:
+        guide_unet = target.structural_copy()
         for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
             sa = attention_block.find(SelfAttention)
             assert sa is not None and sa.parent is not None
             SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
-        for i, attention_block in enumerate(unet.layers(CrossAttentionBlock)):
-            unet.set_context(f"self_attention_context_{i}", {"norm": None})
-            sa = attention_block.find(SelfAttention)
-            assert sa is not None and sa.parent is not None
-            self._adapters.append(ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", sai=self))
         super().__init__(
-            Lambda(self.copy_diffusion_context),
+            Lambda(self._copy_diffusion_context),
             UseContext("self_attention_injection", "guide"),
             guide_unet,
-            Lambda(self.restore_diffusion_context),
+            Lambda(self._restore_diffusion_context),
         )
-    @property
-    def unet(self):
-        return self._unet[0]
-    def inject(self) -> None:
-        assert self not in self._unet[0], f"{self} is already injected"
-        for adapter in self._adapters:
-            adapter.inject()
-        self.unet.insert(0, self)
-    def eject(self) -> None:
-        assert self.unet[0] == self, f"{self} is not the first element of target UNet"
-        for adapter in self._adapters:
-            adapter.eject()
-        self.unet.pop(0)
-    def set_controlnet_condition(self, condition: Tensor) -> None:
-        self.set_context("self_attention_injection", {"guide": condition})
-    def copy_diffusion_context(self, x: Tensor) -> Tensor:
+    def _copy_diffusion_context(self, x: Tensor) -> Tensor:
         # This function allows to not disrupt the accumulation of residuals in the unet (if controlnet are used)
         self.set_context(
             "self_attention_residuals_buffer",
@@ -117,7 +81,7 @@ class SelfAttentionInjection(Passthrough):
         )
         return x
-    def restore_diffusion_context(self, x: Tensor) -> Tensor:
+    def _restore_diffusion_context(self, x: Tensor) -> Tensor:
         self.set_context(
             "unet",
             {
@@ -126,5 +90,50 @@ class SelfAttentionInjection(Passthrough):
         )
         return x
+class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
+    # TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
+    def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
+        # the style_cfg is the weight of the guide in unconditionned diffusion.
+        # This value is recommended to be 0.5 on the sdwebui repo.
+        self.sub_adapters: list[ReferenceOnlyControlAdapter] = []
+        self._passthrough: list[SelfAttentionInjectionPassthrough] = [
+            SelfAttentionInjectionPassthrough(target)
+        ]  # not registered by PyTorch
+        with self.setup_adapter(target):
+            super().__init__(target)
+        for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
+            self.set_context(f"self_attention_context_{i}", {"norm": None})
+            sa = attention_block.find(SelfAttention)
+            assert sa is not None and sa.parent is not None
+            self.sub_adapters.append(
+                ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
+            )
+    def inject(self: "SelfAttentionInjection", parent: Chain | None = None) -> "SelfAttentionInjection":
+        passthrough = self._passthrough[0]
+        assert passthrough not in self.target, f"{passthrough} is already injected"
+        for adapter in self.sub_adapters:
+            adapter.inject()
+        self.target.insert(0, passthrough)
+        return super().inject(parent)
+    def eject(self) -> None:
+        passthrough = self._passthrough[0]
+        assert self.target[0] == passthrough, f"{passthrough} is not the first element of target UNet"
+        for adapter in self.sub_adapters:
+            adapter.eject()
+        self.target.pop(0)
+        super().eject()
+    def set_controlnet_condition(self, condition: Tensor) -> None:
+        self.set_context("self_attention_injection", {"guide": condition})
     def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
         raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")

View file

@@ -3,11 +3,11 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
     StableDiffusion_1,
     StableDiffusion_1_Inpainting,
 )
-from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
 __all__ = [
     "StableDiffusion_1",
     "StableDiffusion_1_Inpainting",
     "SD1UNet",
-    "SD1Controlnet",
+    "SD1ControlnetAdapter",
 ]

View file

@@ -1,11 +1,13 @@
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
+    SD1UNet,
     DownBlocks,
     MiddleBlock,
     ResidualBlock,
     TimestepEncoder,
 )
+from refiners.adapters.adapter import Adapter
 from refiners.adapters.range_adapter import RangeAdapter2d
 from typing import cast, Iterable
 from torch import Tensor, device as Device, dtype as DType
@@ -69,10 +71,12 @@ class ConditionEncoder(Chain):
     )
-class SD1Controlnet(Passthrough):
-    structural_attrs = ["name", "scale"]
+class Controlnet(Passthrough):
+    structural_attrs = ["scale"]
-    def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
+    def __init__(
+        self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
+    ) -> None:
         """Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
         Input is a `batch 3 width height` tensor, output is a `batch 1280 width//8 height//8` tensor with residuals
@@ -80,8 +84,7 @@ class SD1Controlnet(Passthrough):
         It has to use the same context as the UNet: `unet` and `sampling`.
         """
-        self.name = name
-        self.scale: float = 1.0
+        self.scale = scale
         super().__init__(
             TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
             Lambda(lambda x: x.narrow(dim=1, start=0, length=4)),  # support inpainting
@@ -102,15 +105,14 @@ class SD1Controlnet(Passthrough):
         )
         for residual_block in self.layers(ResidualBlock):
             chain = residual_block.Chain
-            range_adapter = RangeAdapter2d(
+            RangeAdapter2d(
                 target=chain.Conv2d_1,
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
-                context_key=f"timestep_embedding_{self.name}",
+                context_key=f"timestep_embedding_{name}",
                 device=device,
                 dtype=dtype,
-            )
-            range_adapter.inject(chain)
+            ).inject(chain)
         for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
             assert hasattr(block[0], "out_channels"), (
                 "The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
@@ -132,14 +134,6 @@ class SD1Controlnet(Passthrough):
             )
         )
-    def init_context(self) -> Contexts:
-        return {
-            "unet": {"residuals": [0.0] * 13},
-            "sampling": {"shapes": []},
-            "controlnet": {f"condition_{self.name}": None},
-            "range_adapter": {f"timestep_embedding_{self.name}": None},
-        }
     def _store_nth_residual(self, n: int):
         def _store_residual(x: Tensor):
             residuals = self.use_context("unet")["residuals"]
@@ -148,8 +142,39 @@ class SD1Controlnet(Passthrough):
         return _store_residual
+class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
+    def __init__(
+        self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
+    ) -> None:
+        self.name = name
+        controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
+        if weights is not None:
+            controlnet.load_state_dict(weights)
+        self._controlnet: list[Controlnet] = [controlnet]  # not registered by PyTorch
+        with self.setup_adapter(target):
+            super().__init__(target)
+    def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
+        controlnet = self._controlnet[0]
+        assert controlnet not in self.target, f"{controlnet} is already injected"
+        self.target.insert(0, controlnet)
+        return super().inject(parent)
+    def eject(self) -> None:
+        self.target.remove(self._controlnet[0])
+        super().eject()
+    def init_context(self) -> Contexts:
+        return {"controlnet": {f"condition_{self.name}": None}}
+    def set_scale(self, scale: float) -> None:
+        self._controlnet[0].scale = scale
     def set_controlnet_condition(self, condition: Tensor) -> None:
         self.set_context("controlnet", {f"condition_{self.name}": condition})
-    def set_scale(self, scale: float) -> None:
-        self.scale = scale
+    def structural_copy(self: "SD1ControlnetAdapter") -> "SD1ControlnetAdapter":
+        raise RuntimeError("Controlnet cannot be copied, eject it first.")

View file

@@ -278,15 +278,14 @@ class SD1UNet(fl.Chain):
         )
         for residual_block in self.layers(ResidualBlock):
             chain = residual_block.Chain
-            range_adapter = RangeAdapter2d(
+            RangeAdapter2d(
                 target=chain.Conv2d_1,
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
                 context_key="timestep_embedding",
                 device=device,
                 dtype=dtype,
-            )
-            range_adapter.inject(chain)
+            ).inject(chain)
         for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
             block.append(ResidualAccumulator(n))
         for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):

View file

@@ -70,12 +70,11 @@ class DoubleTextEncoder(fl.Chain):
     ) -> None:
         text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
         text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
-        text_encoder_with_pooling = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
         super().__init__(
             fl.Parallel(text_encoder_l[:-2], text_encoder_g),
             fl.Lambda(func=self.concatenate_embeddings),
         )
-        text_encoder_with_pooling.inject(parent=self.Parallel)
+        TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)
     def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
         return super().__call__(text)

View file

@@ -261,15 +261,14 @@ class SDXLUNet(fl.Chain):
         )
         for residual_block in self.layers(ResidualBlock):
             chain = residual_block.Chain
-            range_adapter = RangeAdapter2d(
+            RangeAdapter2d(
                 target=chain.Conv2d_1,
                 channels=residual_block.out_channels,
                 embedding_dim=1280,
                 context_key="timestep_embedding",
                 device=device,
                 dtype=dtype,
-            )
-            range_adapter.inject(chain)
+            ).inject(chain)
         for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
             block.append(module=ResidualAccumulator(n=n))
         for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):

View file

@@ -164,8 +164,7 @@ def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
         assert not (
             isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
         ), f"{linear} already has a dropout layer"
-        adapter = DropoutAdapter(target=linear, probability=probability)
-        adapter.inject(parent)
+        DropoutAdapter(target=linear, probability=probability).inject(parent)
 def apply_gyro_dropout(
@@ -181,14 +180,13 @@ def apply_gyro_dropout(
         assert not (
             isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
         ), f"{linear} already has a dropout layer"
-        adapter = GyroDropoutAdapter(
+        GyroDropoutAdapter(
             target=linear,
             probability=probability,
             total_subnetworks=total_subnetworks,
             concurrent_subnetworks=concurrent_subnetworks,
             iters_per_epoch=iters_per_epoch,
-        )
-        adapter.inject(parent)
+        ).inject(parent)
 ConfigType = TypeVar("ConfigType", bound="BaseConfig")

View file

@@ -24,9 +24,8 @@ def test_weighted_module_adapter_insertion(chain: Chain):
     parent = chain.Chain
     adaptee = parent.Linear
-    adapter = DummyLinearAdapter(adaptee)
-    adapter.inject(parent)
+    adapter = DummyLinearAdapter(adaptee).inject(parent)
     assert adapter.parent == parent
     assert adapter in iter(parent)
     assert adaptee not in iter(parent)
@@ -61,8 +60,7 @@ def test_weighted_module_adapter_structural_copy(chain: Chain):
     parent = chain.Chain
     adaptee = parent.Linear
-    adapter = DummyLinearAdapter(adaptee)
-    adapter.inject(parent)
+    DummyLinearAdapter(adaptee).inject(parent)
     clone = chain.structural_copy()
     cloned_adapter = clone.Chain.DummyLinearAdapter
@@ -72,8 +70,7 @@ def test_chain_adapter_structural_copy(chain: Chain):
 def test_chain_adapter_structural_copy(chain: Chain):
     # Chain adapters cannot be copied by default.
-    adapter = DummyChainAdapter(chain.Chain)
-    adapter.inject()
+    adapter = DummyChainAdapter(chain.Chain).inject()
     with pytest.raises(RuntimeError):
         chain.structural_copy()

View file

@@ -1,9 +1,9 @@
-from refiners.adapters.lora import Lora, LoraAdapter
+from refiners.adapters.lora import Lora, SingleLoraAdapter, LoraAdapter
 from torch import randn, allclose
 import refiners.fluxion.layers as fl
-def test_lora() -> None:
+def test_single_lora_adapter() -> None:
     chain = fl.Chain(
         fl.Chain(
             fl.Linear(in_features=1, out_features=1),
@@ -14,8 +14,7 @@ def test_lora() -> None:
     x = randn(1, 1)
     y = chain(x)
-    lora_adapter = LoraAdapter(chain.Chain.Linear_1)
-    lora_adapter.inject(chain.Chain)
+    lora_adapter = SingleLoraAdapter(chain.Chain.Linear_1).inject(chain.Chain)
     assert isinstance(lora_adapter[1], Lora)
     assert allclose(input=chain(x), other=y)
@@ -26,4 +25,18 @@ def test_lora() -> None:
     assert len(chain) == 2
     lora_adapter.inject(chain.Chain)
-    assert isinstance(chain.Chain[0], LoraAdapter)
+    assert isinstance(chain.Chain[0], SingleLoraAdapter)
+def test_lora_adapter() -> None:
+    chain = fl.Chain(
+        fl.Chain(
+            fl.Linear(in_features=1, out_features=1),
+            fl.Linear(in_features=1, out_features=1),
+        ),
+        fl.Linear(in_features=1, out_features=2),
+    )
+    LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
+    assert len(list(chain.layers(Lora))) == 3

View file

@@ -15,8 +15,7 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device):  # FG-
     chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
     adaptee = chain.RangeEncoder.Linear_1
-    adapter = DummyLinearAdapter(adaptee)
-    adapter.inject(chain.RangeEncoder)
+    adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder)
     assert adapter.parent == chain.RangeEncoder

View file

@@ -12,9 +12,9 @@ from refiners.foundationals.latent_diffusion import (
     StableDiffusion_1,
     StableDiffusion_1_Inpainting,
     SD1UNet,
-    SD1Controlnet,
+    SD1ControlnetAdapter,
 )
-from refiners.foundationals.latent_diffusion.lora import LoraWeights
+from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
 from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
 from refiners.foundationals.clip.concepts import ConceptExtender
@@ -57,6 +57,11 @@ def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
+@pytest.fixture
+def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
+    return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")
 @pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
 def controlnet_data(
     ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@@ -85,6 +90,15 @@ def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str,
     return cn_name, condition_image, expected_image, weights_path
+@pytest.fixture(scope="module")
+def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
+    cn_name = "depth"
+    condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
+    expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
+    weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
+    return cn_name, condition_image, expected_image, weights_path
 @pytest.fixture(scope="module")
 def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
     expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
@@ -453,11 +467,9 @@ def test_diffusion_controlnet(
     sd15.set_num_inference_steps(n_steps)
-    controlnet_state_dict = load_from_safetensors(cn_weights_path)
-    controlnet = SD1Controlnet(name=cn_name, device=test_device)
-    controlnet.load_state_dict(controlnet_state_dict)
-    controlnet.set_scale(0.5)
-    sd15.unet.insert(0, controlnet)
+    controlnet = SD1ControlnetAdapter(
+        sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
+    ).inject()
     cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
@@ -502,11 +514,9 @@ def test_diffusion_controlnet_structural_copy(
     sd15.set_num_inference_steps(n_steps)
-    controlnet_state_dict = load_from_safetensors(cn_weights_path)
-    controlnet = SD1Controlnet(name=cn_name, device=test_device)
-    controlnet.load_state_dict(controlnet_state_dict)
-    controlnet.set_scale(0.5)
-    sd15.unet.insert(0, controlnet)
+    controlnet = SD1ControlnetAdapter(
+        sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
+    ).inject()
     cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
@@ -550,11 +560,9 @@ def test_diffusion_controlnet_float16(
     sd15.set_num_inference_steps(n_steps)
-    controlnet_state_dict = load_from_safetensors(cn_weights_path)
-    controlnet = SD1Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
-    controlnet.load_state_dict(controlnet_state_dict)
-    controlnet.set_scale(0.5)
-    sd15.unet.insert(0, controlnet)
+    controlnet = SD1ControlnetAdapter(
+        sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
+    ).inject()
     cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device, dtype=torch.float16)
@@ -575,6 +583,64 @@ def test_diffusion_controlnet_float16(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
+@torch.no_grad()
+def test_diffusion_controlnet_stack(
+    sd15_std: StableDiffusion_1,
+    controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
+    controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
+    expected_image_controlnet_stack: Image.Image,
+    test_device: torch.device,
+):
+    sd15 = sd15_std
+    n_steps = 30
+    _, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
+    _, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
+    if not canny_cn_weights_path.is_file():
+        warn(f"could not find weights at {canny_cn_weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+    if not depth_cn_weights_path.is_file():
+        warn(f"could not find weights at {depth_cn_weights_path}, skipping")
+        pytest.skip(allow_module_level=True)
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+    with torch.no_grad():
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    sd15.set_num_inference_steps(n_steps)
+    depth_controlnet = SD1ControlnetAdapter(
+        sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
+    ).inject()
+    canny_controlnet = SD1ControlnetAdapter(
+        sd15.unet, name="canny", scale=0.7, weights=load_from_safetensors(canny_cn_weights_path)
+    ).inject()
+    depth_cn_condition = image_to_tensor(depth_condition_image.convert("RGB"), device=test_device)
+    canny_cn_condition = image_to_tensor(canny_condition_image.convert("RGB"), device=test_device)
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device)
+    with torch.no_grad():
+        for step in sd15.steps:
+            depth_controlnet.set_controlnet_condition(depth_cn_condition)
+            canny_controlnet.set_controlnet_condition(canny_cn_condition)
+            x = sd15(
+                x,
+                step=step,
+                clip_text_embedding=clip_text_embedding,
+                condition_scale=7.5,
+            )
+        predicted_image = sd15.lda.decode_latents(x)
+    ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
 @torch.no_grad()
 def test_diffusion_lora(
     sd15_std: StableDiffusion_1,
@@ -597,8 +663,7 @@ def test_diffusion_lora(
     sd15.set_num_inference_steps(n_steps)
-    lora_weights = LoraWeights(lora_weights_path, device=test_device)
-    lora_weights.patch(sd15, scale=1.0)
+    SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
@@ -629,8 +694,7 @@ def test_diffusion_refonly(
     with torch.no_grad():
         clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
-    sai = SelfAttentionInjection(sd15.unet)
-    sai.inject()
+    sai = SelfAttentionInjection(sd15.unet).inject()
     guide = sd15.lda.encode_image(condition_image_refonly)
     guide = torch.cat((guide, guide))
@@ -671,29 +735,26 @@ def test_diffusion_inpainting_refonly(
     with torch.no_grad():
         clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
-    sai = SelfAttentionInjection(sd15.unet)
-    sai.inject()
+    sai = SelfAttentionInjection(sd15.unet).inject()
     sd15.set_num_inference_steps(n_steps)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
-    refonly_guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
-    refonly_guide = torch.cat((refonly_guide, refonly_guide))
+    guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
+    guide = torch.cat((guide, guide))
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
     with torch.no_grad():
         for step in sd15.steps:
-            refonly_noise = torch.randn_like(refonly_guide)
-            refonly_noised_guide = sd15.scheduler.add_noise(refonly_guide, refonly_noise, step)
+            noise = torch.randn_like(guide)
+            noised_guide = sd15.scheduler.add_noise(guide, noise, step)
             # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
             # inpaint variation models")
-            refonly_noised_guide = torch.cat(
-                [refonly_noised_guide, torch.zeros_like(refonly_noised_guide)[:, 0:1, :, :], refonly_guide], dim=1
-            )
-            sai.set_controlnet_condition(refonly_noised_guide)
+            noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
+            sai.set_controlnet_condition(noised_guide)
             x = sd15(
                 x,
                 step=step,

Binary file not shown (new image, 385 KiB): the expected reference image for the new ControlNet stacking test (expected_controlnet_stack.png).

View file

@@ -0,0 +1,85 @@
+from typing import Iterator
+import torch
+import pytest
+import refiners.fluxion.layers as fl
+from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
+@pytest.fixture(scope="module", params=[True, False])
+def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
+    with_parent: bool = request.param
+    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    if with_parent:
+        fl.Chain(unet)
+    yield unet
+@torch.no_grad()
+def test_single_controlnet(unet: SD1UNet) -> None:
+    original_parent = unet.parent
+    cn = SD1ControlnetAdapter(unet, name="cn")
+    assert unet.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 0
+    with pytest.raises(ValueError) as exc:
+        cn.eject()
+    assert "not in" in str(exc.value)
+    cn.inject()
+    assert unet.parent == cn
+    assert len(list(unet.walk(Controlnet))) == 1
+    with pytest.raises(AssertionError) as exc:
+        cn.inject()
+    assert "already injected" in str(exc.value)
+    cn.eject()
+    assert unet.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 0
+@torch.no_grad()
+def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
+    original_parent = unet.parent
+    cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
+    cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
+    assert unet.parent == cn2
+    assert unet in cn2
+    assert unet not in cn1
+    assert cn2.parent == cn1
+    assert cn2 in cn1
+    assert cn1.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 2
+    assert cn1.target == unet
+    assert cn1.lookup_actual_target() == cn2
+    cn2.eject()
+    assert unet.parent == cn1
+    assert unet in cn2
+    assert cn2 not in cn1
+    assert unet in cn1
+    assert len(list(unet.walk(Controlnet))) == 1
+    cn1.eject()
+    assert unet.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 0
+@torch.no_grad()
+def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
+    original_parent = unet.parent
+    cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
+    cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
+    cn1.eject()
+    assert cn2.parent == original_parent
+    assert unet.parent == cn2
+    cn2.eject()
+    assert unet.parent == original_parent
+    assert len(list(unet.walk(Controlnet))) == 0

View file

@@ -1,16 +0,0 @@
-from refiners.adapters.lora import Lora
-from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
-import refiners.fluxion.layers as fl
-def test_lora_target_self() -> None:
-    chain = fl.Chain(
-        fl.Chain(
-            fl.Linear(in_features=1, out_features=1),
-            fl.Linear(in_features=1, out_features=1),
-        ),
-        fl.Linear(in_features=1, out_features=2),
-    )
-    apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
-    assert len(list(chain.layers(Lora))) == 3

View file

@@ -0,0 +1,48 @@
+import torch
+import pytest
+from refiners.foundationals.latent_diffusion import SD1UNet
+from refiners.foundationals.latent_diffusion.self_attention_injection import (
+    SelfAttentionInjection,
+    SaveLayerNormAdapter,
+    ReferenceOnlyControlAdapter,
+    SelfAttentionInjectionPassthrough,
+)
+from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
+@torch.no_grad()
+def test_sai_inject_eject() -> None:
+    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
+    sai = SelfAttentionInjection(unet)
+    nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
+    assert nb_cross_attention_blocks > 0
+    assert unet.parent is None
+    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
+    assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
+    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
+    with pytest.raises(AssertionError) as exc:
+        sai.eject()
+    assert "not the first element" in str(exc.value)
+    sai.inject()
+    assert unet.parent == sai
+    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
+    assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
+    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == nb_cross_attention_blocks
+    with pytest.raises(AssertionError) as exc:
+        sai.inject()
+    assert "already injected" in str(exc.value)
+    sai.eject()
+    assert unet.parent is None
+    assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
+    assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
+    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0