make high-level adapters Adapters

This generalizes the Adapter abstraction to higher-level
constructs such as high-level LoRA (targeting e.g. the
SD UNet), ControlNet and Reference-Only Control.

Some adapters now work by adapting child models with
"sub-adapters" that they inject and eject as needed.
Pierre Chapuis 2023-08-31 10:40:01 +02:00
parent 7dc2e93cff
commit 0f476ea18b
24 changed files with 603 additions and 282 deletions
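
For context, here is a minimal sketch of the pattern this commit introduces, using the `SD1ControlnetAdapter` defined in the hunks below (illustrative only; weight loading is elided):

```python
import torch

from refiners.foundationals.latent_diffusion import StableDiffusion_1, SD1ControlnetAdapter

sd15 = StableDiffusion_1()

# High-level constructs are now Adapters: they wrap their target model and
# inject themselves (plus any sub-adapters) into the model tree.
controlnet = SD1ControlnetAdapter(sd15.unet, name="canny", scale=0.5).inject()
controlnet.set_controlnet_condition(torch.randn(1, 3, 512, 512))

# ... run inference as usual, then eject to restore the unpatched model ...
controlnet.eject()
```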

View file

@ -179,12 +179,11 @@ The `Adapter` API lets you **easily patch models** by injecting parameters in ta
E.g. to inject LoRA layers into the linear layers of all attention blocks:
```python
from refiners.adapters.lora import LoraAdapter        # before this commit
from refiners.adapters.lora import SingleLoraAdapter  # after this commit

for layer in vit.layers(fl.Attention):
    for linear, parent in layer.walk(fl.Linear):
        adapter = LoraAdapter(target=linear, rank=64)             # before
        adapter.inject(parent)                                    # before
        SingleLoraAdapter(target=linear, rank=64).inject(parent)  # after
# ... and load existing weights if the LoRAs are pretrained ...
```
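
Loading pretrained weights into an injected `SingleLoraAdapter` could then look like this (a hedged sketch; `up` and `down` stand for tensors loaded from your own checkpoint):

```python
adapter = SingleLoraAdapter(target=linear, rank=64).inject(parent)
adapter.load_lora_weights(up_weight=up, down_weight=down)
```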
@ -232,7 +231,7 @@ Step 3: run inference using the GPU:
```python
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.lora import LoraWeights     # before this commit
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter  # after this commit
from refiners.fluxion.utils import load_from_safetensors, manual_seed
import torch
@ -242,9 +241,7 @@ sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors")
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
# This uses the LoraAdapter internally and takes care of injecting it where needed
lora_weights = LoraWeights("pokemon_lora.safetensors", device=sd15.device)  # before this commit
lora_weights.patch(sd15, scale=1.0)                                         # before this commit
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()  # after
prompt = "a cute cat"
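# Since SD1LoraAdapter is a genuine Adapter, the patch is also reversible now (hedged sketch):
#   adapter = SD1LoraAdapter.from_safetensors(
#       target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0
#   ).inject()
#   ... generate images ...
#   adapter.eject()  # sd15 is back to its unpatched state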

View file

@ -8,7 +8,7 @@ from refiners.fluxion.utils import save_to_safetensors
from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion import (
SD1UNet,
SD1Controlnet,
SD1ControlnetAdapter,
DPMSolver,
)
@ -21,13 +21,13 @@ class Args(argparse.Namespace):
@torch.no_grad()
def convert(args: Args) -> dict[str, torch.Tensor]:
controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore
controlnet = SD1Controlnet(name="mycn")
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
controlnet = unet.Controlnet
condition = torch.randn(1, 3, 512, 512)
controlnet.set_controlnet_condition(condition=condition)
adapter.set_controlnet_condition(condition=condition)
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
unet.insert(index=0, module=controlnet)
clip_text_embedding = torch.rand(1, 77, 768)
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

View file

@ -1,17 +1,20 @@
import argparse
from pathlib import Path
from typing import cast
import torch
from torch import Tensor
from torch.nn.init import zeros_
from torch.nn import Parameter as TorchParameter
from diffusers import DiffusionPipeline # type: ignore
import refiners.fluxion.layers as fl
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.adapters.lora import Lora, LoraAdapter
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
from refiners.adapters.lora import Lora
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
def get_weight(linear: fl.Linear) -> torch.Tensor:
@ -69,7 +72,8 @@ def process(args: Args) -> None:
diffusers_to_refiners = converter.get_mapping()
apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
for layer in refiners_model.layers(layer_type=Lora):
zeros_(tensor=layer.Linear_1.weight)
@ -85,7 +89,9 @@ def process(args: Args) -> None:
p = p[seg]
assert isinstance(p, fl.Chain)
last_seg = (
"LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
"SingleLoraAdapter"
if orig_path[-1] == "Linear"
else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
)
p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])

View file

@ -2,9 +2,8 @@ import random
from typing import Any
from pydantic import BaseModel
from loguru import logger
from refiners.adapters.lora import LoraAdapter, Lora
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
import refiners.fluxion.layers as fl
from torch import Tensor
from torch.utils.data import Dataset
@ -27,13 +26,6 @@ class LoraConfig(BaseModel):
text_encoder_targets: list[LoraTarget]
lda_targets: list[LoraTarget]
def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
for linear, parent in lora_targets(module, target):
adapter = LoraAdapter(target=linear, rank=self.rank)
adapter.inject(parent)
for linear in adapter.Lora.layers(fl.Linear):
linear.requires_grad_(requires_grad=True)
class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
def __init__(
@ -84,45 +76,30 @@ class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfi
class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
lora_config = trainer.config.lora
for target in lora_config.unet_targets:
lora_config.apply_loras_to_target(module=trainer.unet, target=target)
for target in lora_config.text_encoder_targets:
lora_config.apply_loras_to_target(module=trainer.text_encoder, target=target)
for target in lora_config.lda_targets:
lora_config.apply_loras_to_target(module=trainer.lda, target=target)
for model_name in MODELS:
model = getattr(trainer, model_name)
adapter = LoraAdapter[type(model)](
model,
sub_targets=getattr(lora_config, f"{model_name}_targets"),
rank=lora_config.rank,
)
for sub_adapter, _ in adapter.sub_adapters:
for linear in sub_adapter.Lora.layers(fl.Linear):
linear.requires_grad_(requires_grad=True)
adapter.inject()
class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
lora_config = trainer.config.lora
def get_weight(linear: fl.Linear) -> Tensor:
assert linear.bias is None
return linear.state_dict()["weight"]
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, Tensor]:
weights: list[Tensor] = []
for lora in module.layers(layer_type=Lora):
linears = list(lora.layers(fl.Linear))
assert len(linears) == 2
# See `load_lora_weights` in refiners.adapters.lora
weights.extend((get_weight(linears[1]), get_weight(linears[0]))) # aka (up_weight, down_weight)
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
tensors: dict[str, Tensor] = {}
metadata: dict[str, str] = {}
if lora_config.unet_targets:
tensors |= build_loras_safetensors(trainer.unet, key_prefix="unet.")
metadata |= {"unet_targets": ",".join(lora_config.unet_targets)}
if lora_config.text_encoder_targets:
tensors |= build_loras_safetensors(trainer.text_encoder, key_prefix="text_encoder.")
metadata |= {"text_encoder_targets": ",".join(lora_config.text_encoder_targets)}
if lora_config.lda_targets:
tensors |= build_loras_safetensors(trainer.lda, key_prefix="lda.")
metadata |= {"lda_targets": ",".join(lora_config.lda_targets)}
for model_name in MODELS:
model = getattr(trainer, model_name)
adapter = model.parent
tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}
save_to_safetensors(
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
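# For reference, a hedged sketch of the checkpoint layout this produces:
#   tensors:  {"unet.000": up_0, "unet.001": down_0, ..., "lda.000": ...}
#   metadata: {"unet_targets": "...", "text_encoder_targets": "...", "lda_targets": "..."}
# Sorted keys alternate (up_weight, down_weight) per Lora layer, which is the
# order SD1LoraAdapter.from_safetensors relies on when re-loading.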

View file

@ -4,7 +4,7 @@ from typing import Any, Generic, TypeVar, Iterator
T = TypeVar("T", bound=fl.Module)
TAdapter = TypeVar("TAdapter", bound="Adapter[fl.Module]")
TAdapter = TypeVar("TAdapter", bound="Adapter[Any]") # Self (see PEP 673)
class Adapter(Generic[T]):
@ -36,27 +36,61 @@ class Adapter(Generic[T]):
yield
target._can_refresh_parent = _old_can_refresh_parent
def inject(self, parent: fl.Chain | None = None) -> None:
def lookup_actual_target(self) -> fl.Module:
# In general, the "actual target" is the target.
# This method deals with the edge case where the target
# is part of the replacement block and has been adapted by
# another adapter after this one. For instance, this is the
# case when stacking Controlnets.
assert isinstance(self, fl.Chain)
target_parent = self.find_parent(self.target)
if (target_parent is None) or (target_parent == self):
return self.target
# Look up and return the last adapter in the parents tree (or the target if there is none).
r, p = self.target, target_parent
while p != self:
if isinstance(p, Adapter):
r = p
assert p.parent, f"parent tree of {self} is broken"
p = p.parent
return r
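# Illustration (hedged; mirrors the controlnet tests added below): with two
# stacked adapters on the same UNet,
#   cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
#   cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
# cn1.target is still the UNet, while cn1.lookup_actual_target() is cn2, so
# ejecting cn1 keeps cn2 injected instead of detaching the whole stack.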
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
assert isinstance(self, fl.Chain)
if (parent is None) and isinstance(self.target, fl.ContextModule):
parent = self.target.parent
if parent is not None:
assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
target_parent = self.find_parent(self.target)
if parent is None:
if isinstance(self.target, fl.ContextModule):
parent = self.target.parent
else:
raise ValueError(f"parent of {self.target} is mandatory")
assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage]
return self
if self.target not in iter(parent):
raise ValueError(f"{self.target} is not in {parent}")
parent.replace(
old_module=self.target,
new_module=self,
old_module_parent=self.find_parent(self.target),
old_module_parent=target_parent,
)
return self
def eject(self) -> None:
assert isinstance(self, fl.Chain)
self.ensure_parent.replace(old_module=self, new_module=self.target)
actual_target = self.lookup_actual_target()
if (parent := self.parent) is None:
if isinstance(actual_target, fl.ContextModule):
actual_target._set_parent(None) # type: ignore[reportPrivateUsage]
else:
parent.replace(old_module=self, new_module=actual_target)
def _pre_structural_copy(self) -> None:
if isinstance(self.target, fl.Chain):

View file

@ -1,8 +1,14 @@
from typing import Iterable, Generic, TypeVar, Any
import refiners.fluxion.layers as fl
from refiners.adapters.adapter import Adapter
from torch.nn.init import zeros_, normal_
from torch import Tensor, device as Device, dtype as DType
from torch.nn import Parameter as TorchParameter
from torch.nn.init import zeros_, normal_
T = TypeVar("T", bound=fl.Chain)
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
class Lora(fl.Chain):
@ -37,11 +43,19 @@ class Lora(fl.Chain):
self.scale = scale
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
self.Linear_1.weight = down_weight
self.Linear_2.weight = up_weight
self.Linear_1.weight = TorchParameter(down_weight)
self.Linear_2.weight = TorchParameter(up_weight)
@property
def up_weight(self) -> Tensor:
return self.Linear_2.weight.data
@property
def down_weight(self) -> Tensor:
return self.Linear_1.weight.data
class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
structural_attrs = ["in_features", "out_features", "rank", "scale"]
def __init__(
@ -67,20 +81,54 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
)
self.Lora.set_scale(scale=scale)
def add_lora(self, lora: Lora) -> None:
self.append(module=lora)
def load_lora_weights(self, up_weight: Tensor, down_weight: Tensor, index: int = 0) -> None:
self[index + 1].load_weights(up_weight=up_weight, down_weight=down_weight)
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
def __init__(
self,
target: T,
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
rank: int | None = None,
scale: float = 1.0,
weights: list[Tensor] | None = None,
) -> None:
with self.setup_adapter(target):
super().__init__(target)
if weights is not None:
assert len(weights) % 2 == 0
weights_rank = weights[0].shape[1]
if rank is None:
rank = weights_rank
else:
assert rank == weights_rank
def load_lora_weights(model: fl.Chain, weights: list[Tensor]) -> None:
assert len(weights) % 2 == 0, "Number of weights must be even"
assert (
len(list(model.layers(layer_type=Lora))) == len(weights) // 2
), "Number of Lora layers must match number of weights"
for i, lora in enumerate(iterable=model.layers(layer_type=Lora)):
assert rank is not None, "either pass a rank or weights"
self.sub_targets = sub_targets
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
for linear, parent in self.sub_targets:
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
if weights is not None:
assert len(self.sub_adapters) == (len(weights) // 2)
for i, (adapter, _) in enumerate(self.sub_adapters):
lora = adapter.Lora
assert (
lora.rank == weights[i * 2].shape[1]
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
for adapter, adapter_parent in self.sub_adapters:
adapter.inject(adapter_parent)
return super().inject(parent)
def eject(self) -> None:
for adapter, _ in self.sub_adapters:
adapter.eject()
super().eject()
@property
def weights(self) -> list[Tensor]:
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
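
Putting the pieces together, a minimal usage sketch of the new generic `LoraAdapter` (it mirrors the new `test_lora_adapter` test below; the chain layout is illustrative):

```python
import refiners.fluxion.layers as fl
from refiners.adapters.lora import Lora, LoraAdapter

chain = fl.Chain(
    fl.Chain(
        fl.Linear(in_features=1, out_features=1),
        fl.Linear(in_features=1, out_features=1),
    ),
    fl.Linear(in_features=1, out_features=2),
)

# One SingleLoraAdapter is created and injected per (linear, parent) sub-target.
adapter = LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
assert len(list(chain.layers(Lora))) == 3
assert len(adapter.weights) == 6  # one (up_weight, down_weight) pair per Lora

adapter.eject()  # ejects every sub-adapter as well
assert len(list(chain.layers(Lora))) == 0
```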

View file

@ -10,7 +10,7 @@ from torch.nn import Parameter
import re
class ConceptExtender:
class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
"""
Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
@ -37,6 +37,9 @@ class ConceptExtender:
"""
def __init__(self, target: CLIPTextEncoder) -> None:
with self.setup_adapter(target):
super().__init__(target)
try:
token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
except StopIteration:
@ -54,13 +57,15 @@ class ConceptExtender:
self.embedding_extender.add_embedding(embedding)
self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
def inject(self) -> None:
def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
self.embedding_extender.inject(self.token_encoder_parent)
self.token_extender.inject(self.clip_tokenizer_parent)
return super().inject(parent)
def eject(self) -> None:
self.embedding_extender.eject()
self.token_extender.eject()
super().eject()
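# Usage sketch (hedged; assumes the method whose body appears above is named
# `add_concept(token, embedding)`, with `embedding` taken from e.g. a textual
# inversion checkpoint):
#   extender = ConceptExtender(clip_text_encoder)
#   extender.add_concept("<my-concept>", embedding)
#   extender.inject()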
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):

View file

@ -9,7 +9,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
SD1UNet,
SD1Controlnet,
SD1ControlnetAdapter,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
SDXLUNet,
@ -21,7 +21,7 @@ __all__ = [
"StableDiffusion_1",
"StableDiffusion_1_Inpainting",
"SD1UNet",
"SD1Controlnet",
"SD1ControlnetAdapter",
"SDXLUNet",
"DoubleTextEncoder",
"DPMSolver",

View file

@ -2,17 +2,26 @@ from enum import Enum
from pathlib import Path
from typing import Iterator
from torch import Tensor
from torch import Tensor, device as Device
from torch.nn import Parameter as TorchParameter
from refiners.adapters.lora import LoraAdapter, load_lora_weights
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
from refiners.adapters.adapter import Adapter
from refiners.adapters.lora import SingleLoraAdapter, LoraAdapter
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion import (
StableDiffusion_1,
SD1UNet,
CLIPTextEncoderL,
LatentDiffusionAutoencoder,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
MODELS = ["unet", "text_encoder", "lda"]
class LoraTarget(str, Enum):
Self = "self"
@ -38,66 +47,95 @@ class LoraTarget(str, Enum):
return TransformerLayer
def get_lora_rank(weights: list[Tensor]) -> int:
ranks: set[int] = {w.shape[1] for w in weights[0::2]}
assert len(ranks) == 1
return ranks.pop()
def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
for layer in it:
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
if isinstance(module, SD1UNet):
def predicate(m: fl.Module, p: fl.Chain) -> bool:
if isinstance(m, Controlnet): # do not adapt Controlnet linears
raise StopIteration
return isinstance(m, lookup_class)
else:
def predicate(m: fl.Module, p: fl.Chain) -> bool:
return isinstance(m, lookup_class)
if target == LoraTarget.Self:
for m, p in module.walk(predicate):
assert isinstance(m, fl.Linear)
yield (m, p)
return
for layer, _ in module.walk(predicate):
for t in layer.walk(fl.Linear):
yield t
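# Note (assumption about `walk` semantics): a predicate raising StopIteration
# prunes recursion into that subtree, which is how the linears of an injected
# Controlnet are excluded from LoRA targeting here.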
def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
for linear, parent in lora_targets(module, target):
adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
adapter.inject(parent)
class LoraWeights:
"""A single LoRA weights training checkpoint used to patch a Stable Diffusion 1.5 model."""
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
metadata: dict[str, str] | None
tensors: dict[str, Tensor]
def __init__(self, checkpoint_path: Path | str, device: Device | str):
self.metadata = load_metadata_from_safetensors(checkpoint_path)
self.tensors = load_from_safetensors(checkpoint_path, device=device)
def __init__(
self,
target: StableDiffusion_1,
sub_targets: dict[str, list[LoraTarget]],
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
):
with self.setup_adapter(target):
super().__init__(target)
def patch(self, sd: StableDiffusion_1, scale: float = 1.0) -> None:
assert self.metadata is not None, "Invalid safetensors checkpoint: missing metadata"
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
for meta_key, meta_value in self.metadata.items():
match meta_key:
case "unet_targets":
# TODO: support this transparently
if any([isinstance(module, SD1Controlnet) for module in sd.unet]):
raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
model = sd.unet
key_prefix = "unet."
case "text_encoder_targets":
model = sd.clip_text_encoder
key_prefix = "text_encoder."
case "lda_targets":
model = sd.lda
key_prefix = "lda."
case _:
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
for model_name in MODELS:
if not (model_targets := sub_targets.get(model_name, [])):
continue
model = getattr(target, model_name)
if model.find(SingleLoraAdapter):
raise NotImplementedError(f"{model} already contains LoRA layers")
# TODO(FG-487): support loading multiple LoRA-s
if any(model.layers(LoraAdapter)):
raise NotImplementedError(f"{model.__class__.__name__} already contains LoRA layers")
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
self.sub_adapters.append(
LoraAdapter[type(model)](
model,
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
scale=scale,
weights=lora_weights,
)
)
lora_weights = [w for w in [self.tensors[k] for k in sorted(self.tensors) if k.startswith(key_prefix)]]
assert len(lora_weights) % 2 == 0
@classmethod
def from_safetensors(
cls,
target: StableDiffusion_1,
checkpoint_path: Path | str,
scale: float = 1.0,
):
metadata = load_metadata_from_safetensors(checkpoint_path)
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)
rank = get_lora_rank(lora_weights)
for target in meta_value.split(","):
apply_loras_to_target(model, target=LoraTarget(target), rank=rank, scale=scale)
sub_targets = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
assert len(list(model.layers(LoraAdapter))) == (len(lora_weights) // 2)
return cls(
target,
sub_targets,
scale=scale,
weights=tensors,
)
load_lora_weights(model, [TorchParameter(w) for w in lora_weights])
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
for adapter in self.sub_adapters:
adapter.inject()
return super().inject(parent)
def eject(self) -> None:
for adapter in self.sub_adapters:
adapter.eject()
super().eject()

View file

@ -27,10 +27,10 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
self,
target: SelfAttention,
context: str,
sai: "SelfAttentionInjection",
style_cfg: float = 0.5,
) -> None:
self.context = context
self._sai = [sai] # only to support setting `style_cfg` dynamically
self.style_cfg = style_cfg
sa_guided = target.structural_copy()
assert isinstance(sa_guided[0], Parallel)
@ -50,62 +50,26 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
)
def compute_averaged_unconditioned_x(self, x: Tensor, unguided_unconditioned_x: Tensor) -> Tensor:
style_cfg = self._sai[0].style_cfg
x[0] = style_cfg * x[0] + (1.0 - style_cfg) * unguided_unconditioned_x
x[0] = self.style_cfg * x[0] + (1.0 - self.style_cfg) * unguided_unconditioned_x
return x
class SelfAttentionInjection(Passthrough):
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
def __init__(self, unet: SD1UNet, style_cfg: float = 0.5) -> None:
# the style_cfg is the weight of the guide in unconditioned diffusion.
# A value of 0.5 is recommended in the sd-webui repo.
self.style_cfg = style_cfg
self._adapters: list[ReferenceOnlyControlAdapter] = []
self._unet = [unet]
guide_unet = unet.structural_copy()
class SelfAttentionInjectionPassthrough(Passthrough):
def __init__(self, target: SD1UNet) -> None:
guide_unet = target.structural_copy()
for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
sa = attention_block.find(SelfAttention)
assert sa is not None and sa.parent is not None
SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
for i, attention_block in enumerate(unet.layers(CrossAttentionBlock)):
unet.set_context(f"self_attention_context_{i}", {"norm": None})
sa = attention_block.find(SelfAttention)
assert sa is not None and sa.parent is not None
self._adapters.append(ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", sai=self))
super().__init__(
Lambda(self.copy_diffusion_context),
Lambda(self._copy_diffusion_context),
UseContext("self_attention_injection", "guide"),
guide_unet,
Lambda(self.restore_diffusion_context),
Lambda(self._restore_diffusion_context),
)
@property
def unet(self):
return self._unet[0]
def inject(self) -> None:
assert self not in self._unet[0], f"{self} is already injected"
for adapter in self._adapters:
adapter.inject()
self.unet.insert(0, self)
def eject(self) -> None:
assert self.unet[0] == self, f"{self} is not the first element of target UNet"
for adapter in self._adapters:
adapter.eject()
self.unet.pop(0)
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("self_attention_injection", {"guide": condition})
def copy_diffusion_context(self, x: Tensor) -> Tensor:
def _copy_diffusion_context(self, x: Tensor) -> Tensor:
# This function avoids disrupting the accumulation of residuals in the UNet (when ControlNets are used)
self.set_context(
"self_attention_residuals_buffer",
@ -117,7 +81,7 @@ class SelfAttentionInjection(Passthrough):
)
return x
def restore_diffusion_context(self, x: Tensor) -> Tensor:
def _restore_diffusion_context(self, x: Tensor) -> Tensor:
self.set_context(
"unet",
{
@ -126,5 +90,50 @@ class SelfAttentionInjection(Passthrough):
)
return x
class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
# the style_cfg is the weight of the guide in unconditioned diffusion.
# A value of 0.5 is recommended in the sd-webui repo.
self.sub_adapters: list[ReferenceOnlyControlAdapter] = []
self._passthrough: list[SelfAttentionInjectionPassthrough] = [
SelfAttentionInjectionPassthrough(target)
] # not registered by PyTorch
with self.setup_adapter(target):
super().__init__(target)
for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
self.set_context(f"self_attention_context_{i}", {"norm": None})
sa = attention_block.find(SelfAttention)
assert sa is not None and sa.parent is not None
self.sub_adapters.append(
ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
)
def inject(self: "SelfAttentionInjection", parent: Chain | None = None) -> "SelfAttentionInjection":
passthrough = self._passthrough[0]
assert passthrough not in self.target, f"{passthrough} is already injected"
for adapter in self.sub_adapters:
adapter.inject()
self.target.insert(0, passthrough)
return super().inject(parent)
def eject(self) -> None:
passthrough = self._passthrough[0]
assert self.target[0] == passthrough, f"{passthrough} is not the first element of target UNet"
for adapter in self.sub_adapters:
adapter.eject()
self.target.pop(0)
super().eject()
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("self_attention_injection", {"guide": condition})
def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")

View file

@ -3,11 +3,11 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
__all__ = [
"StableDiffusion_1",
"StableDiffusion_1_Inpainting",
"SD1UNet",
"SD1Controlnet",
"SD1ControlnetAdapter",
]

View file

@ -1,11 +1,13 @@
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
SD1UNet,
DownBlocks,
MiddleBlock,
ResidualBlock,
TimestepEncoder,
)
from refiners.adapters.adapter import Adapter
from refiners.adapters.range_adapter import RangeAdapter2d
from typing import cast, Iterable
from torch import Tensor, device as Device, dtype as DType
@ -69,10 +71,12 @@ class ConditionEncoder(Chain):
)
class SD1Controlnet(Passthrough):
structural_attrs = ["name", "scale"]
class Controlnet(Passthrough):
structural_attrs = ["scale"]
def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
def __init__(
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
) -> None:
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
Input is a `batch 3 width height` tensor, output is a `batch 1280 width//8 height//8` tensor with residuals
@ -80,8 +84,7 @@ class SD1Controlnet(Passthrough):
It has to use the same context as the UNet: `unet` and `sampling`.
"""
self.name = name
self.scale: float = 1.0
self.scale = scale
super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting
@ -102,15 +105,14 @@ class SD1Controlnet(Passthrough):
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
range_adapter = RangeAdapter2d(
RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key=f"timestep_embedding_{self.name}",
context_key=f"timestep_embedding_{name}",
device=device,
dtype=dtype,
)
range_adapter.inject(chain)
).inject(chain)
for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
assert hasattr(block[0], "out_channels"), (
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
@ -132,14 +134,6 @@ class SD1Controlnet(Passthrough):
)
)
def init_context(self) -> Contexts:
return {
"unet": {"residuals": [0.0] * 13},
"sampling": {"shapes": []},
"controlnet": {f"condition_{self.name}": None},
"range_adapter": {f"timestep_embedding_{self.name}": None},
}
def _store_nth_residual(self, n: int):
def _store_residual(x: Tensor):
residuals = self.use_context("unet")["residuals"]
@ -148,8 +142,39 @@ class SD1Controlnet(Passthrough):
return _store_residual
class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
def __init__(
self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
) -> None:
self.name = name
controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
if weights is not None:
controlnet.load_state_dict(weights)
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
with self.setup_adapter(target):
super().__init__(target)
def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
controlnet = self._controlnet[0]
assert controlnet not in self.target, f"{controlnet} is already injected"
self.target.insert(0, controlnet)
return super().inject(parent)
def eject(self) -> None:
self.target.remove(self._controlnet[0])
super().eject()
def init_context(self) -> Contexts:
return {"controlnet": {f"condition_{self.name}": None}}
def set_scale(self, scale: float) -> None:
self._controlnet[0].scale = scale
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("controlnet", {f"condition_{self.name}": condition})
def set_scale(self, scale: float) -> None:
self.scale = scale
def structural_copy(self: "SD1ControlnetAdapter") -> "SD1ControlnetAdapter":
raise RuntimeError("Controlnet cannot be copied, eject it first.")

View file

@ -278,15 +278,14 @@ class SD1UNet(fl.Chain):
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
range_adapter = RangeAdapter2d(
RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key="timestep_embedding",
device=device,
dtype=dtype,
)
range_adapter.inject(chain)
).inject(chain)
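# Note: `inject` now returns the adapter (see the adapter.py changes above),
# which is what allows this construct-and-inject one-liner.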
for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
block.append(ResidualAccumulator(n))
for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):

View file

@ -70,12 +70,11 @@ class DoubleTextEncoder(fl.Chain):
) -> None:
text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
text_encoder_with_pooling = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
super().__init__(
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
fl.Lambda(func=self.concatenate_embeddings),
)
text_encoder_with_pooling.inject(parent=self.Parallel)
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
return super().__call__(text)

View file

@ -261,15 +261,14 @@ class SDXLUNet(fl.Chain):
)
for residual_block in self.layers(ResidualBlock):
chain = residual_block.Chain
range_adapter = RangeAdapter2d(
RangeAdapter2d(
target=chain.Conv2d_1,
channels=residual_block.out_channels,
embedding_dim=1280,
context_key="timestep_embedding",
device=device,
dtype=dtype,
)
range_adapter.inject(chain)
).inject(chain)
for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
block.append(module=ResidualAccumulator(n=n))
for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):

View file

@ -164,8 +164,7 @@ def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
assert not (
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
), f"{linear} already has a dropout layer"
adapter = DropoutAdapter(target=linear, probability=probability)
adapter.inject(parent)
DropoutAdapter(target=linear, probability=probability).inject(parent)
def apply_gyro_dropout(
@ -181,14 +180,13 @@ def apply_gyro_dropout(
assert not (
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
), f"{linear} already has a dropout layer"
adapter = GyroDropoutAdapter(
GyroDropoutAdapter(
target=linear,
probability=probability,
total_subnetworks=total_subnetworks,
concurrent_subnetworks=concurrent_subnetworks,
iters_per_epoch=iters_per_epoch,
)
adapter.inject(parent)
).inject(parent)
ConfigType = TypeVar("ConfigType", bound="BaseConfig")

View file

@ -24,9 +24,8 @@ def test_weighted_module_adapter_insertion(chain: Chain):
parent = chain.Chain
adaptee = parent.Linear
adapter = DummyLinearAdapter(adaptee)
adapter = DummyLinearAdapter(adaptee).inject(parent)
adapter.inject(parent)
assert adapter.parent == parent
assert adapter in iter(parent)
assert adaptee not in iter(parent)
@ -61,8 +60,7 @@ def test_weighted_module_adapter_structural_copy(chain: Chain):
parent = chain.Chain
adaptee = parent.Linear
adapter = DummyLinearAdapter(adaptee)
adapter.inject(parent)
DummyLinearAdapter(adaptee).inject(parent)
clone = chain.structural_copy()
cloned_adapter = clone.Chain.DummyLinearAdapter
@ -72,8 +70,7 @@ def test_weighted_module_adapter_structural_copy(chain: Chain):
def test_chain_adapter_structural_copy(chain: Chain):
# Chain adapters cannot be copied by default.
adapter = DummyChainAdapter(chain.Chain)
adapter.inject()
adapter = DummyChainAdapter(chain.Chain).inject()
with pytest.raises(RuntimeError):
chain.structural_copy()

View file

@ -1,9 +1,9 @@
from refiners.adapters.lora import Lora, LoraAdapter
from refiners.adapters.lora import Lora, SingleLoraAdapter, LoraAdapter
from torch import randn, allclose
import refiners.fluxion.layers as fl
def test_lora() -> None:
def test_single_lora_adapter() -> None:
chain = fl.Chain(
fl.Chain(
fl.Linear(in_features=1, out_features=1),
@ -14,8 +14,7 @@ def test_lora() -> None:
x = randn(1, 1)
y = chain(x)
lora_adapter = LoraAdapter(chain.Chain.Linear_1)
lora_adapter.inject(chain.Chain)
lora_adapter = SingleLoraAdapter(chain.Chain.Linear_1).inject(chain.Chain)
assert isinstance(lora_adapter[1], Lora)
assert allclose(input=chain(x), other=y)
@ -26,4 +25,18 @@ def test_lora() -> None:
assert len(chain) == 2
lora_adapter.inject(chain.Chain)
assert isinstance(chain.Chain[0], LoraAdapter)
assert isinstance(chain.Chain[0], SingleLoraAdapter)
def test_lora_adapter() -> None:
chain = fl.Chain(
fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
),
fl.Linear(in_features=1, out_features=2),
)
LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
assert len(list(chain.layers(Lora))) == 3

View file

@ -15,8 +15,7 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG-
chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
adaptee = chain.RangeEncoder.Linear_1
adapter = DummyLinearAdapter(adaptee)
adapter.inject(chain.RangeEncoder)
adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder)
assert adapter.parent == chain.RangeEncoder

View file

@ -12,9 +12,9 @@ from refiners.foundationals.latent_diffusion import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
SD1UNet,
SD1Controlnet,
SD1ControlnetAdapter,
)
from refiners.foundationals.latent_diffusion.lora import LoraWeights
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
from refiners.foundationals.clip.concepts import ConceptExtender
@ -57,6 +57,11 @@ def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
@pytest.fixture
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
def controlnet_data(
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
@ -85,6 +90,15 @@ def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str,
return cn_name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
cn_name = "depth"
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
return cn_name, condition_image, expected_image, weights_path
@pytest.fixture(scope="module")
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
@ -453,11 +467,9 @@ def test_diffusion_controlnet(
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = SD1Controlnet(name=cn_name, device=test_device)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
).inject()
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
@ -502,11 +514,9 @@ def test_diffusion_controlnet_structural_copy(
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = SD1Controlnet(name=cn_name, device=test_device)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
).inject()
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
@ -550,11 +560,9 @@ def test_diffusion_controlnet_float16(
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = SD1Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
controlnet = SD1ControlnetAdapter(
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
).inject()
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device, dtype=torch.float16)
@ -575,6 +583,64 @@ def test_diffusion_controlnet_float16(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_controlnet_stack(
sd15_std: StableDiffusion_1,
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
expected_image_controlnet_stack: Image.Image,
test_device: torch.device,
):
sd15 = sd15_std
n_steps = 30
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
_, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
if not canny_cn_weights_path.is_file():
warn(f"could not find weights at {canny_cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
if not depth_cn_weights_path.is_file():
warn(f"could not find weights at {depth_cn_weights_path}, skipping")
pytest.skip(allow_module_level=True)
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
depth_controlnet = SD1ControlnetAdapter(
sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
).inject()
canny_controlnet = SD1ControlnetAdapter(
sd15.unet, name="canny", scale=0.7, weights=load_from_safetensors(canny_cn_weights_path)
).inject()
depth_cn_condition = image_to_tensor(depth_condition_image.convert("RGB"), device=test_device)
canny_cn_condition = image_to_tensor(canny_condition_image.convert("RGB"), device=test_device)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
depth_controlnet.set_controlnet_condition(depth_cn_condition)
canny_controlnet.set_controlnet_condition(canny_cn_condition)
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
def test_diffusion_lora(
sd15_std: StableDiffusion_1,
@ -597,8 +663,7 @@ def test_diffusion_lora(
sd15.set_num_inference_steps(n_steps)
lora_weights = LoraWeights(lora_weights_path, device=test_device)
lora_weights.patch(sd15, scale=1.0)
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
@ -629,8 +694,7 @@ def test_diffusion_refonly(
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sai = SelfAttentionInjection(sd15.unet)
sai.inject()
sai = SelfAttentionInjection(sd15.unet).inject()
guide = sd15.lda.encode_image(condition_image_refonly)
guide = torch.cat((guide, guide))
@ -671,29 +735,26 @@ def test_diffusion_inpainting_refonly(
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sai = SelfAttentionInjection(sd15.unet)
sai.inject()
sai = SelfAttentionInjection(sd15.unet).inject()
sd15.set_num_inference_steps(n_steps)
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
refonly_guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
refonly_guide = torch.cat((refonly_guide, refonly_guide))
guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
guide = torch.cat((guide, guide))
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
with torch.no_grad():
for step in sd15.steps:
refonly_noise = torch.randn_like(refonly_guide)
refonly_noised_guide = sd15.scheduler.add_noise(refonly_guide, refonly_noise, step)
noise = torch.randn_like(guide)
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
# See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
# inpaint variation models")
refonly_noised_guide = torch.cat(
[refonly_noised_guide, torch.zeros_like(refonly_noised_guide)[:, 0:1, :, :], refonly_guide], dim=1
)
noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
sai.set_controlnet_condition(refonly_noised_guide)
sai.set_controlnet_condition(noised_guide)
x = sd15(
x,
step=step,

Binary file not shown (new image added, 385 KiB).

View file

@ -0,0 +1,85 @@
from typing import Iterator
import torch
import pytest
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
@pytest.fixture(scope="module", params=[True, False])
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
with_parent: bool = request.param
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
if with_parent:
fl.Chain(unet)
yield unet
@torch.no_grad()
def test_single_controlnet(unet: SD1UNet) -> None:
original_parent = unet.parent
cn = SD1ControlnetAdapter(unet, name="cn")
assert unet.parent == original_parent
assert len(list(unet.walk(Controlnet))) == 0
with pytest.raises(ValueError) as exc:
cn.eject()
assert "not in" in str(exc.value)
cn.inject()
assert unet.parent == cn
assert len(list(unet.walk(Controlnet))) == 1
with pytest.raises(AssertionError) as exc:
cn.inject()
assert "already injected" in str(exc.value)
cn.eject()
assert unet.parent == original_parent
assert len(list(unet.walk(Controlnet))) == 0
@torch.no_grad()
def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
original_parent = unet.parent
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
assert unet.parent == cn2
assert unet in cn2
assert unet not in cn1
assert cn2.parent == cn1
assert cn2 in cn1
assert cn1.parent == original_parent
assert len(list(unet.walk(Controlnet))) == 2
assert cn1.target == unet
assert cn1.lookup_actual_target() == cn2
cn2.eject()
assert unet.parent == cn1
assert unet in cn2
assert cn2 not in cn1
assert unet in cn1
assert len(list(unet.walk(Controlnet))) == 1
cn1.eject()
assert unet.parent == original_parent
assert len(list(unet.walk(Controlnet))) == 0
@torch.no_grad()
def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
original_parent = unet.parent
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
cn1.eject()
assert cn2.parent == original_parent
assert unet.parent == cn2
cn2.eject()
assert unet.parent == original_parent
assert len(list(unet.walk(Controlnet))) == 0

View file

@ -1,16 +0,0 @@
from refiners.adapters.lora import Lora
from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
import refiners.fluxion.layers as fl
def test_lora_target_self() -> None:
chain = fl.Chain(
fl.Chain(
fl.Linear(in_features=1, out_features=1),
fl.Linear(in_features=1, out_features=1),
),
fl.Linear(in_features=1, out_features=2),
)
apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
assert len(list(chain.layers(Lora))) == 3

View file

@ -0,0 +1,48 @@
import torch
import pytest
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.self_attention_injection import (
SelfAttentionInjection,
SaveLayerNormAdapter,
ReferenceOnlyControlAdapter,
SelfAttentionInjectionPassthrough,
)
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
@torch.no_grad()
def test_sai_inject_eject() -> None:
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
sai = SelfAttentionInjection(unet)
nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
assert nb_cross_attention_blocks > 0
assert unet.parent is None
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
with pytest.raises(AssertionError) as exc:
sai.eject()
assert "not the first element" in str(exc.value)
sai.inject()
assert unet.parent == sai
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == nb_cross_attention_blocks
with pytest.raises(AssertionError) as exc:
sai.inject()
assert "already injected" in str(exc.value)
sai.eject()
assert unet.parent is None
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0