mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
make high-level adapters Adapters
This generalizes the Adapter abstraction to higher-level constructs such as high-level LoRA (targeting e.g. the SD UNet), ControlNet and Reference-Only Control. Some adapters now work by adapting child models with "sub-adapters" that they inject / eject when needed.
This commit is contained in:
parent
7dc2e93cff
commit
0f476ea18b
11
README.md
11
README.md
|
@ -179,12 +179,11 @@ The `Adapter` API lets you **easily patch models** by injecting parameters in ta
|
|||
E.g. to inject LoRA layers in all attention's linear layers:
|
||||
|
||||
```python
|
||||
from refiners.adapters.lora import LoraAdapter
|
||||
from refiners.adapters.lora import SingleLoraAdapter
|
||||
|
||||
for layer in vit.layers(fl.Attention):
|
||||
for linear, parent in layer.walk(fl.Linear):
|
||||
adapter = LoraAdapter(target=linear, rank=64)
|
||||
adapter.inject(parent)
|
||||
SingleLoraAdapter(target=linear, rank=64).inject(parent)
|
||||
|
||||
# ... and load existing weights if the LoRAs are pretrained ...
|
||||
```
|
||||
|
@ -232,7 +231,7 @@ Step 3: run inference using the GPU:
|
|||
|
||||
```python
|
||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed
|
||||
import torch
|
||||
|
||||
|
@ -242,9 +241,7 @@ sd15.clip_text_encoder.load_state_dict(load_from_safetensors("clip.safetensors")
|
|||
sd15.lda.load_state_dict(load_from_safetensors("lda.safetensors"))
|
||||
sd15.unet.load_state_dict(load_from_safetensors("unet.safetensors"))
|
||||
|
||||
# This uses the LoraAdapter internally and takes care to inject it where it should
|
||||
lora_weights = LoraWeights("pokemon_lora.safetensors", device=sd15.device)
|
||||
lora_weights.patch(sd15, scale=1.0)
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path="pokemon_lora.safetensors", scale=1.0).inject()
|
||||
|
||||
prompt = "a cute cat"
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ from refiners.fluxion.utils import save_to_safetensors
|
|||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.foundationals.latent_diffusion import (
|
||||
SD1UNet,
|
||||
SD1Controlnet,
|
||||
SD1ControlnetAdapter,
|
||||
DPMSolver,
|
||||
)
|
||||
|
||||
|
@ -21,13 +21,13 @@ class Args(argparse.Namespace):
|
|||
@torch.no_grad()
|
||||
def convert(args: Args) -> dict[str, torch.Tensor]:
|
||||
controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore
|
||||
controlnet = SD1Controlnet(name="mycn")
|
||||
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
|
||||
controlnet = unet.Controlnet
|
||||
|
||||
condition = torch.randn(1, 3, 512, 512)
|
||||
controlnet.set_controlnet_condition(condition=condition)
|
||||
adapter.set_controlnet_condition(condition=condition)
|
||||
|
||||
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||
unet.insert(index=0, module=controlnet)
|
||||
clip_text_embedding = torch.rand(1, 77, 768)
|
||||
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||
|
||||
|
|
|
@ -1,17 +1,20 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn.init import zeros_
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
|
||||
from diffusers import DiffusionPipeline # type: ignore
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.adapters.lora import Lora, LoraAdapter
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
|
||||
from refiners.adapters.lora import Lora
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
|
||||
|
||||
|
||||
def get_weight(linear: fl.Linear) -> torch.Tensor:
|
||||
|
@ -69,7 +72,8 @@ def process(args: Args) -> None:
|
|||
|
||||
diffusers_to_refiners = converter.get_mapping()
|
||||
|
||||
apply_loras_to_target(module=refiners_model, target=LoraTarget(target), rank=rank, scale=1.0)
|
||||
LoraAdapter[SD1UNet](refiners_model, sub_targets=lora_targets(refiners_model, target), rank=rank).inject()
|
||||
|
||||
for layer in refiners_model.layers(layer_type=Lora):
|
||||
zeros_(tensor=layer.Linear_1.weight)
|
||||
|
||||
|
@ -85,7 +89,9 @@ def process(args: Args) -> None:
|
|||
p = p[seg]
|
||||
assert isinstance(p, fl.Chain)
|
||||
last_seg = (
|
||||
"LoraAdapter" if orig_path[-1] == "Linear" else f"LoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
|
||||
"SingleLoraAdapter"
|
||||
if orig_path[-1] == "Linear"
|
||||
else f"SingleLoraAdapter_{orig_path[-1].removeprefix('Linear_')}"
|
||||
)
|
||||
p_down = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.down.weight"])
|
||||
p_up = TorchParameter(data=diffusers_state_dict[f"{target_k}_lora.up.weight"])
|
||||
|
|
|
@ -2,9 +2,8 @@ import random
|
|||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from refiners.adapters.lora import LoraAdapter, Lora
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, LoraAdapter, MODELS
|
||||
import refiners.fluxion.layers as fl
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
@ -27,13 +26,6 @@ class LoraConfig(BaseModel):
|
|||
text_encoder_targets: list[LoraTarget]
|
||||
lda_targets: list[LoraTarget]
|
||||
|
||||
def apply_loras_to_target(self, module: fl.Chain, target: LoraTarget) -> None:
|
||||
for linear, parent in lora_targets(module, target):
|
||||
adapter = LoraAdapter(target=linear, rank=self.rank)
|
||||
adapter.inject(parent)
|
||||
for linear in adapter.Lora.layers(fl.Linear):
|
||||
linear.requires_grad_(requires_grad=True)
|
||||
|
||||
|
||||
class TriggerPhraseDataset(TextEmbeddingLatentsDataset):
|
||||
def __init__(
|
||||
|
@ -84,45 +76,30 @@ class LoraLatentDiffusionTrainer(LatentDiffusionTrainer[LoraLatentDiffusionConfi
|
|||
class LoadLoras(Callback[LoraLatentDiffusionTrainer]):
|
||||
def on_train_begin(self, trainer: LoraLatentDiffusionTrainer) -> None:
|
||||
lora_config = trainer.config.lora
|
||||
for target in lora_config.unet_targets:
|
||||
lora_config.apply_loras_to_target(module=trainer.unet, target=target)
|
||||
for target in lora_config.text_encoder_targets:
|
||||
lora_config.apply_loras_to_target(module=trainer.text_encoder, target=target)
|
||||
for target in lora_config.lda_targets:
|
||||
lora_config.apply_loras_to_target(module=trainer.lda, target=target)
|
||||
|
||||
for model_name in MODELS:
|
||||
model = getattr(trainer, model_name)
|
||||
adapter = LoraAdapter[type(model)](
|
||||
model,
|
||||
sub_targets=getattr(lora_config, f"{model_name}_targets"),
|
||||
rank=lora_config.rank,
|
||||
)
|
||||
for sub_adapter, _ in adapter.sub_adapters:
|
||||
for linear in sub_adapter.Lora.layers(fl.Linear):
|
||||
linear.requires_grad_(requires_grad=True)
|
||||
adapter.inject()
|
||||
|
||||
|
||||
class SaveLoras(Callback[LoraLatentDiffusionTrainer]):
|
||||
def on_checkpoint_save(self, trainer: LoraLatentDiffusionTrainer) -> None:
|
||||
lora_config = trainer.config.lora
|
||||
|
||||
def get_weight(linear: fl.Linear) -> Tensor:
|
||||
assert linear.bias is None
|
||||
return linear.state_dict()["weight"]
|
||||
|
||||
def build_loras_safetensors(module: fl.Chain, key_prefix: str) -> dict[str, Tensor]:
|
||||
weights: list[Tensor] = []
|
||||
for lora in module.layers(layer_type=Lora):
|
||||
linears = list(lora.layers(fl.Linear))
|
||||
assert len(linears) == 2
|
||||
# See `load_lora_weights` in refiners.adapters.lora
|
||||
weights.extend((get_weight(linears[1]), get_weight(linears[0]))) # aka (up_weight, down_weight)
|
||||
return {f"{key_prefix}{i:03d}": w for i, w in enumerate(weights)}
|
||||
|
||||
tensors: dict[str, Tensor] = {}
|
||||
metadata: dict[str, str] = {}
|
||||
|
||||
if lora_config.unet_targets:
|
||||
tensors |= build_loras_safetensors(trainer.unet, key_prefix="unet.")
|
||||
metadata |= {"unet_targets": ",".join(lora_config.unet_targets)}
|
||||
|
||||
if lora_config.text_encoder_targets:
|
||||
tensors |= build_loras_safetensors(trainer.text_encoder, key_prefix="text_encoder.")
|
||||
metadata |= {"text_encoder_targets": ",".join(lora_config.text_encoder_targets)}
|
||||
|
||||
if lora_config.lda_targets:
|
||||
tensors |= build_loras_safetensors(trainer.lda, key_prefix="lda.")
|
||||
metadata |= {"lda_targets": ",".join(lora_config.lda_targets)}
|
||||
for model_name in MODELS:
|
||||
model = getattr(trainer, model_name)
|
||||
adapter = model.parent
|
||||
tensors |= {f"{model_name}.{i:03d}": w for i, w in enumerate(adapter.weights)}
|
||||
metadata |= {f"{model_name}_targets": ",".join(adapter.sub_targets)}
|
||||
|
||||
save_to_safetensors(
|
||||
path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors",
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Any, Generic, TypeVar, Iterator
|
|||
|
||||
|
||||
T = TypeVar("T", bound=fl.Module)
|
||||
TAdapter = TypeVar("TAdapter", bound="Adapter[fl.Module]")
|
||||
TAdapter = TypeVar("TAdapter", bound="Adapter[Any]") # Self (see PEP 673)
|
||||
|
||||
|
||||
class Adapter(Generic[T]):
|
||||
|
@ -36,27 +36,61 @@ class Adapter(Generic[T]):
|
|||
yield
|
||||
target._can_refresh_parent = _old_can_refresh_parent
|
||||
|
||||
def inject(self, parent: fl.Chain | None = None) -> None:
|
||||
def lookup_actual_target(self) -> fl.Module:
|
||||
# In general, the "actual target" is the target.
|
||||
# This method deals with the edge case where the target
|
||||
# is part of the replacement block and has been adapted by
|
||||
# another adapter after this one. For instance, this is the
|
||||
# case when stacking Controlnets.
|
||||
assert isinstance(self, fl.Chain)
|
||||
|
||||
target_parent = self.find_parent(self.target)
|
||||
if (target_parent is None) or (target_parent == self):
|
||||
return self.target
|
||||
|
||||
# Lookup and return last adapter in parents tree (or target if none).
|
||||
r, p = self.target, target_parent
|
||||
while p != self:
|
||||
if isinstance(p, Adapter):
|
||||
r = p
|
||||
assert p.parent, f"parent tree of {self} is broken"
|
||||
p = p.parent
|
||||
return r
|
||||
|
||||
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
|
||||
assert isinstance(self, fl.Chain)
|
||||
|
||||
if (parent is None) and isinstance(self.target, fl.ContextModule):
|
||||
parent = self.target.parent
|
||||
if parent is not None:
|
||||
assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
|
||||
|
||||
target_parent = self.find_parent(self.target)
|
||||
|
||||
if parent is None:
|
||||
if isinstance(self.target, fl.ContextModule):
|
||||
parent = self.target.parent
|
||||
else:
|
||||
raise ValueError(f"parent of {self.target} is mandatory")
|
||||
assert isinstance(parent, fl.Chain), f"{self.target} has invalid parent {parent}"
|
||||
self.target._set_parent(target_parent) # type: ignore[reportPrivateUsage]
|
||||
return self
|
||||
|
||||
if self.target not in iter(parent):
|
||||
raise ValueError(f"{self.target} is not in {parent}")
|
||||
|
||||
parent.replace(
|
||||
old_module=self.target,
|
||||
new_module=self,
|
||||
old_module_parent=self.find_parent(self.target),
|
||||
old_module_parent=target_parent,
|
||||
)
|
||||
return self
|
||||
|
||||
def eject(self) -> None:
|
||||
assert isinstance(self, fl.Chain)
|
||||
self.ensure_parent.replace(old_module=self, new_module=self.target)
|
||||
actual_target = self.lookup_actual_target()
|
||||
|
||||
if (parent := self.parent) is None:
|
||||
if isinstance(actual_target, fl.ContextModule):
|
||||
actual_target._set_parent(None) # type: ignore[reportPrivateUsage]
|
||||
else:
|
||||
parent.replace(old_module=self, new_module=actual_target)
|
||||
|
||||
def _pre_structural_copy(self) -> None:
|
||||
if isinstance(self.target, fl.Chain):
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
from typing import Iterable, Generic, TypeVar, Any
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.adapters.adapter import Adapter
|
||||
|
||||
from torch.nn.init import zeros_, normal_
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
from torch.nn.init import zeros_, normal_
|
||||
|
||||
T = TypeVar("T", bound=fl.Chain)
|
||||
TLoraAdapter = TypeVar("TLoraAdapter", bound="LoraAdapter[Any]") # Self (see PEP 673)
|
||||
|
||||
|
||||
class Lora(fl.Chain):
|
||||
|
@ -37,11 +43,19 @@ class Lora(fl.Chain):
|
|||
self.scale = scale
|
||||
|
||||
def load_weights(self, down_weight: Tensor, up_weight: Tensor) -> None:
|
||||
self.Linear_1.weight = down_weight
|
||||
self.Linear_2.weight = up_weight
|
||||
self.Linear_1.weight = TorchParameter(down_weight)
|
||||
self.Linear_2.weight = TorchParameter(up_weight)
|
||||
|
||||
@property
|
||||
def up_weight(self) -> Tensor:
|
||||
return self.Linear_2.weight.data
|
||||
|
||||
@property
|
||||
def down_weight(self) -> Tensor:
|
||||
return self.Linear_1.weight.data
|
||||
|
||||
|
||||
class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
|
||||
class SingleLoraAdapter(fl.Sum, Adapter[fl.Linear]):
|
||||
structural_attrs = ["in_features", "out_features", "rank", "scale"]
|
||||
|
||||
def __init__(
|
||||
|
@ -67,20 +81,54 @@ class LoraAdapter(fl.Sum, Adapter[fl.Linear]):
|
|||
)
|
||||
self.Lora.set_scale(scale=scale)
|
||||
|
||||
def add_lora(self, lora: Lora) -> None:
|
||||
self.append(module=lora)
|
||||
|
||||
def load_lora_weights(self, up_weight: Tensor, down_weight: Tensor, index: int = 0) -> None:
|
||||
self[index + 1].load_weights(up_weight=up_weight, down_weight=down_weight)
|
||||
class LoraAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
def __init__(
|
||||
self,
|
||||
target: T,
|
||||
sub_targets: Iterable[tuple[fl.Linear, fl.Chain]],
|
||||
rank: int | None = None,
|
||||
scale: float = 1.0,
|
||||
weights: list[Tensor] | None = None,
|
||||
) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
if weights is not None:
|
||||
assert len(weights) % 2 == 0
|
||||
weights_rank = weights[0].shape[1]
|
||||
if rank is None:
|
||||
rank = weights_rank
|
||||
else:
|
||||
assert rank == weights_rank
|
||||
|
||||
def load_lora_weights(model: fl.Chain, weights: list[Tensor]) -> None:
|
||||
assert len(weights) % 2 == 0, "Number of weights must be even"
|
||||
assert (
|
||||
len(list(model.layers(layer_type=Lora))) == len(weights) // 2
|
||||
), "Number of Lora layers must match number of weights"
|
||||
for i, lora in enumerate(iterable=model.layers(layer_type=Lora)):
|
||||
assert (
|
||||
lora.rank == weights[i * 2].shape[1]
|
||||
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
|
||||
lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
|
||||
assert rank is not None, "either pass a rank or weights"
|
||||
|
||||
self.sub_targets = sub_targets
|
||||
self.sub_adapters: list[tuple[SingleLoraAdapter, fl.Chain]] = []
|
||||
|
||||
for linear, parent in self.sub_targets:
|
||||
self.sub_adapters.append((SingleLoraAdapter(target=linear, rank=rank, scale=scale), parent))
|
||||
|
||||
if weights is not None:
|
||||
assert len(self.sub_adapters) == (len(weights) // 2)
|
||||
for i, (adapter, _) in enumerate(self.sub_adapters):
|
||||
lora = adapter.Lora
|
||||
assert (
|
||||
lora.rank == weights[i * 2].shape[1]
|
||||
), f"Rank of Lora layer {lora.rank} must match shape of weights {weights[i*2].shape[1]}"
|
||||
adapter.Lora.load_weights(up_weight=weights[i * 2], down_weight=weights[i * 2 + 1])
|
||||
|
||||
def inject(self: TLoraAdapter, parent: fl.Chain | None = None) -> TLoraAdapter:
|
||||
for adapter, adapter_parent in self.sub_adapters:
|
||||
adapter.inject(adapter_parent)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
for adapter, _ in self.sub_adapters:
|
||||
adapter.eject()
|
||||
super().eject()
|
||||
|
||||
@property
|
||||
def weights(self) -> list[Tensor]:
|
||||
return [w for adapter, _ in self.sub_adapters for w in [adapter.Lora.up_weight, adapter.Lora.down_weight]]
|
||||
|
|
|
@ -10,7 +10,7 @@ from torch.nn import Parameter
|
|||
import re
|
||||
|
||||
|
||||
class ConceptExtender:
|
||||
class ConceptExtender(fl.Chain, Adapter[CLIPTextEncoder]):
|
||||
"""
|
||||
Extends the vocabulary of a CLIPTextEncoder with one or multiple new concepts, e.g. obtained via the Textual Inversion technique.
|
||||
|
||||
|
@ -37,6 +37,9 @@ class ConceptExtender:
|
|||
"""
|
||||
|
||||
def __init__(self, target: CLIPTextEncoder) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
try:
|
||||
token_encoder, self.token_encoder_parent = next(target.walk(TokenEncoder))
|
||||
except StopIteration:
|
||||
|
@ -54,13 +57,15 @@ class ConceptExtender:
|
|||
self.embedding_extender.add_embedding(embedding)
|
||||
self.token_extender.add_token(token, self.embedding_extender.num_embeddings - 1)
|
||||
|
||||
def inject(self) -> None:
|
||||
def inject(self: "ConceptExtender", parent: fl.Chain | None = None) -> "ConceptExtender":
|
||||
self.embedding_extender.inject(self.token_encoder_parent)
|
||||
self.token_extender.inject(self.clip_tokenizer_parent)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
self.embedding_extender.eject()
|
||||
self.token_extender.eject()
|
||||
super().eject()
|
||||
|
||||
|
||||
class EmbeddingExtender(fl.Chain, Adapter[TokenEncoder]):
|
||||
|
|
|
@ -9,7 +9,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
|||
StableDiffusion_1,
|
||||
StableDiffusion_1_Inpainting,
|
||||
SD1UNet,
|
||||
SD1Controlnet,
|
||||
SD1ControlnetAdapter,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
||||
SDXLUNet,
|
||||
|
@ -21,7 +21,7 @@ __all__ = [
|
|||
"StableDiffusion_1",
|
||||
"StableDiffusion_1_Inpainting",
|
||||
"SD1UNet",
|
||||
"SD1Controlnet",
|
||||
"SD1ControlnetAdapter",
|
||||
"SDXLUNet",
|
||||
"DoubleTextEncoder",
|
||||
"DPMSolver",
|
||||
|
|
|
@ -2,17 +2,26 @@ from enum import Enum
|
|||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from torch import Tensor, device as Device
|
||||
from torch.nn import Parameter as TorchParameter
|
||||
from refiners.adapters.lora import LoraAdapter, load_lora_weights
|
||||
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
||||
|
||||
from refiners.adapters.adapter import Adapter
|
||||
from refiners.adapters.lora import SingleLoraAdapter, LoraAdapter
|
||||
|
||||
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||
from refiners.foundationals.latent_diffusion import (
|
||||
StableDiffusion_1,
|
||||
SD1UNet,
|
||||
CLIPTextEncoderL,
|
||||
LatentDiffusionAutoencoder,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
||||
|
||||
MODELS = ["unet", "text_encoder", "lda"]
|
||||
|
||||
|
||||
class LoraTarget(str, Enum):
|
||||
Self = "self"
|
||||
|
@ -38,66 +47,95 @@ class LoraTarget(str, Enum):
|
|||
return TransformerLayer
|
||||
|
||||
|
||||
def get_lora_rank(weights: list[Tensor]) -> int:
|
||||
ranks: set[int] = {w.shape[1] for w in weights[0::2]}
|
||||
assert len(ranks) == 1
|
||||
return ranks.pop()
|
||||
|
||||
|
||||
def lora_targets(module: fl.Chain, target: LoraTarget) -> Iterator[tuple[fl.Linear, fl.Chain]]:
|
||||
it = [module] if target == LoraTarget.Self else module.layers(layer_type=target.get_class())
|
||||
for layer in it:
|
||||
lookup_class = fl.Linear if target == LoraTarget.Self else target.get_class()
|
||||
|
||||
if isinstance(module, SD1UNet):
|
||||
|
||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
||||
if isinstance(m, Controlnet): # do not adapt Controlnet linears
|
||||
raise StopIteration
|
||||
return isinstance(m, lookup_class)
|
||||
|
||||
else:
|
||||
|
||||
def predicate(m: fl.Module, p: fl.Chain) -> bool:
|
||||
return isinstance(m, lookup_class)
|
||||
|
||||
if target == LoraTarget.Self:
|
||||
for m, p in module.walk(predicate):
|
||||
assert isinstance(m, fl.Linear)
|
||||
yield (m, p)
|
||||
return
|
||||
|
||||
for layer, _ in module.walk(predicate):
|
||||
for t in layer.walk(fl.Linear):
|
||||
yield t
|
||||
|
||||
|
||||
def apply_loras_to_target(module: fl.Chain, target: LoraTarget, rank: int, scale: float) -> None:
|
||||
for linear, parent in lora_targets(module, target):
|
||||
adapter = LoraAdapter(target=linear, rank=rank, scale=scale)
|
||||
adapter.inject(parent)
|
||||
|
||||
|
||||
class LoraWeights:
|
||||
"""A single LoRA weights training checkpoint used to patch a Stable Diffusion 1.5 model."""
|
||||
|
||||
class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
||||
metadata: dict[str, str] | None
|
||||
tensors: dict[str, Tensor]
|
||||
|
||||
def __init__(self, checkpoint_path: Path | str, device: Device | str):
|
||||
self.metadata = load_metadata_from_safetensors(checkpoint_path)
|
||||
self.tensors = load_from_safetensors(checkpoint_path, device=device)
|
||||
def __init__(
|
||||
self,
|
||||
target: StableDiffusion_1,
|
||||
sub_targets: dict[str, list[LoraTarget]],
|
||||
scale: float = 1.0,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
):
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
def patch(self, sd: StableDiffusion_1, scale: float = 1.0) -> None:
|
||||
assert self.metadata is not None, "Invalid safetensors checkpoint: missing metadata"
|
||||
self.sub_adapters: list[LoraAdapter[SD1UNet | CLIPTextEncoderL | LatentDiffusionAutoencoder]] = []
|
||||
|
||||
for meta_key, meta_value in self.metadata.items():
|
||||
match meta_key:
|
||||
case "unet_targets":
|
||||
# TODO: support this transparently
|
||||
if any([isinstance(module, SD1Controlnet) for module in sd.unet]):
|
||||
raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
|
||||
model = sd.unet
|
||||
key_prefix = "unet."
|
||||
case "text_encoder_targets":
|
||||
model = sd.clip_text_encoder
|
||||
key_prefix = "text_encoder."
|
||||
case "lda_targets":
|
||||
model = sd.lda
|
||||
key_prefix = "lda."
|
||||
case _:
|
||||
raise ValueError(f"Unexpected key in checkpoint metadata: {meta_key}")
|
||||
for model_name in MODELS:
|
||||
if not (model_targets := sub_targets.get(model_name, [])):
|
||||
continue
|
||||
model = getattr(target, model_name)
|
||||
if model.find(SingleLoraAdapter):
|
||||
raise NotImplementedError(f"{model} already contains LoRA layers")
|
||||
|
||||
# TODO(FG-487): support loading multiple LoRA-s
|
||||
if any(model.layers(LoraAdapter)):
|
||||
raise NotImplementedError(f"{model.__class__.__name__} already contains LoRA layers")
|
||||
lora_weights = [weights[k] for k in sorted(weights) if k.startswith(model_name)] if weights else None
|
||||
self.sub_adapters.append(
|
||||
LoraAdapter[type(model)](
|
||||
model,
|
||||
sub_targets=[x for target in model_targets for x in lora_targets(model, target)],
|
||||
scale=scale,
|
||||
weights=lora_weights,
|
||||
)
|
||||
)
|
||||
|
||||
lora_weights = [w for w in [self.tensors[k] for k in sorted(self.tensors) if k.startswith(key_prefix)]]
|
||||
assert len(lora_weights) % 2 == 0
|
||||
@classmethod
|
||||
def from_safetensors(
|
||||
cls,
|
||||
target: StableDiffusion_1,
|
||||
checkpoint_path: Path | str,
|
||||
scale: float = 1.0,
|
||||
):
|
||||
metadata = load_metadata_from_safetensors(checkpoint_path)
|
||||
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
|
||||
tensors = load_from_safetensors(checkpoint_path, device=target.device)
|
||||
|
||||
rank = get_lora_rank(lora_weights)
|
||||
for target in meta_value.split(","):
|
||||
apply_loras_to_target(model, target=LoraTarget(target), rank=rank, scale=scale)
|
||||
sub_targets = {}
|
||||
for model_name in MODELS:
|
||||
if not (v := metadata.get(f"{model_name}_targets", "")):
|
||||
continue
|
||||
sub_targets[model_name] = [LoraTarget(x) for x in v.split(",")]
|
||||
|
||||
assert len(list(model.layers(LoraAdapter))) == (len(lora_weights) // 2)
|
||||
return cls(
|
||||
target,
|
||||
sub_targets,
|
||||
scale=scale,
|
||||
weights=tensors,
|
||||
)
|
||||
|
||||
load_lora_weights(model, [TorchParameter(w) for w in lora_weights])
|
||||
def inject(self: "SD1LoraAdapter", parent: fl.Chain | None = None) -> "SD1LoraAdapter":
|
||||
for adapter in self.sub_adapters:
|
||||
adapter.inject()
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
for adapter in self.sub_adapters:
|
||||
adapter.eject()
|
||||
super().eject()
|
||||
|
|
|
@ -27,10 +27,10 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
|
|||
self,
|
||||
target: SelfAttention,
|
||||
context: str,
|
||||
sai: "SelfAttentionInjection",
|
||||
style_cfg: float = 0.5,
|
||||
) -> None:
|
||||
self.context = context
|
||||
self._sai = [sai] # only to support setting `style_cfg` dynamically
|
||||
self.style_cfg = style_cfg
|
||||
|
||||
sa_guided = target.structural_copy()
|
||||
assert isinstance(sa_guided[0], Parallel)
|
||||
|
@ -50,62 +50,26 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
|
|||
)
|
||||
|
||||
def compute_averaged_unconditioned_x(self, x: Tensor, unguided_unconditioned_x: Tensor) -> Tensor:
|
||||
style_cfg = self._sai[0].style_cfg
|
||||
x[0] = style_cfg * x[0] + (1.0 - style_cfg) * unguided_unconditioned_x
|
||||
x[0] = self.style_cfg * x[0] + (1.0 - self.style_cfg) * unguided_unconditioned_x
|
||||
return x
|
||||
|
||||
|
||||
class SelfAttentionInjection(Passthrough):
|
||||
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
|
||||
|
||||
def __init__(self, unet: SD1UNet, style_cfg: float = 0.5) -> None:
|
||||
# the style_cfg is the weight of the guide in unconditionned diffusion.
|
||||
# This value is recommended to be 0.5 on the sdwebui repo.
|
||||
self.style_cfg = style_cfg
|
||||
self._adapters: list[ReferenceOnlyControlAdapter] = []
|
||||
self._unet = [unet]
|
||||
|
||||
guide_unet = unet.structural_copy()
|
||||
class SelfAttentionInjectionPassthrough(Passthrough):
|
||||
def __init__(self, target: SD1UNet) -> None:
|
||||
guide_unet = target.structural_copy()
|
||||
for i, attention_block in enumerate(guide_unet.layers(CrossAttentionBlock)):
|
||||
sa = attention_block.find(SelfAttention)
|
||||
assert sa is not None and sa.parent is not None
|
||||
SaveLayerNormAdapter(sa, context=f"self_attention_context_{i}").inject()
|
||||
|
||||
for i, attention_block in enumerate(unet.layers(CrossAttentionBlock)):
|
||||
unet.set_context(f"self_attention_context_{i}", {"norm": None})
|
||||
|
||||
sa = attention_block.find(SelfAttention)
|
||||
assert sa is not None and sa.parent is not None
|
||||
|
||||
self._adapters.append(ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", sai=self))
|
||||
|
||||
super().__init__(
|
||||
Lambda(self.copy_diffusion_context),
|
||||
Lambda(self._copy_diffusion_context),
|
||||
UseContext("self_attention_injection", "guide"),
|
||||
guide_unet,
|
||||
Lambda(self.restore_diffusion_context),
|
||||
Lambda(self._restore_diffusion_context),
|
||||
)
|
||||
|
||||
@property
|
||||
def unet(self):
|
||||
return self._unet[0]
|
||||
|
||||
def inject(self) -> None:
|
||||
assert self not in self._unet[0], f"{self} is already injected"
|
||||
for adapter in self._adapters:
|
||||
adapter.inject()
|
||||
self.unet.insert(0, self)
|
||||
|
||||
def eject(self) -> None:
|
||||
assert self.unet[0] == self, f"{self} is not the first element of target UNet"
|
||||
for adapter in self._adapters:
|
||||
adapter.eject()
|
||||
self.unet.pop(0)
|
||||
|
||||
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||
self.set_context("self_attention_injection", {"guide": condition})
|
||||
|
||||
def copy_diffusion_context(self, x: Tensor) -> Tensor:
|
||||
def _copy_diffusion_context(self, x: Tensor) -> Tensor:
|
||||
# This function allows to not disrupt the accumulation of residuals in the unet (if controlnet are used)
|
||||
self.set_context(
|
||||
"self_attention_residuals_buffer",
|
||||
|
@ -117,7 +81,7 @@ class SelfAttentionInjection(Passthrough):
|
|||
)
|
||||
return x
|
||||
|
||||
def restore_diffusion_context(self, x: Tensor) -> Tensor:
|
||||
def _restore_diffusion_context(self, x: Tensor) -> Tensor:
|
||||
self.set_context(
|
||||
"unet",
|
||||
{
|
||||
|
@ -126,5 +90,50 @@ class SelfAttentionInjection(Passthrough):
|
|||
)
|
||||
return x
|
||||
|
||||
|
||||
class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
|
||||
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
|
||||
|
||||
def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
|
||||
# the style_cfg is the weight of the guide in unconditionned diffusion.
|
||||
# This value is recommended to be 0.5 on the sdwebui repo.
|
||||
|
||||
self.sub_adapters: list[ReferenceOnlyControlAdapter] = []
|
||||
self._passthrough: list[SelfAttentionInjectionPassthrough] = [
|
||||
SelfAttentionInjectionPassthrough(target)
|
||||
] # not registered by PyTorch
|
||||
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
for i, attention_block in enumerate(target.layers(CrossAttentionBlock)):
|
||||
self.set_context(f"self_attention_context_{i}", {"norm": None})
|
||||
|
||||
sa = attention_block.find(SelfAttention)
|
||||
assert sa is not None and sa.parent is not None
|
||||
|
||||
self.sub_adapters.append(
|
||||
ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
|
||||
)
|
||||
|
||||
def inject(self: "SelfAttentionInjection", parent: Chain | None = None) -> "SelfAttentionInjection":
|
||||
passthrough = self._passthrough[0]
|
||||
assert passthrough not in self.target, f"{passthrough} is already injected"
|
||||
for adapter in self.sub_adapters:
|
||||
adapter.inject()
|
||||
self.target.insert(0, passthrough)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
passthrough = self._passthrough[0]
|
||||
assert self.target[0] == passthrough, f"{passthrough} is not the first element of target UNet"
|
||||
for adapter in self.sub_adapters:
|
||||
adapter.eject()
|
||||
self.target.pop(0)
|
||||
super().eject()
|
||||
|
||||
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||
self.set_context("self_attention_injection", {"guide": condition})
|
||||
|
||||
def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
|
||||
raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")
|
||||
|
|
|
@ -3,11 +3,11 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
|||
StableDiffusion_1,
|
||||
StableDiffusion_1_Inpainting,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
||||
|
||||
__all__ = [
|
||||
"StableDiffusion_1",
|
||||
"StableDiffusion_1_Inpainting",
|
||||
"SD1UNet",
|
||||
"SD1Controlnet",
|
||||
"SD1ControlnetAdapter",
|
||||
]
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from refiners.fluxion.context import Contexts
|
||||
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||
SD1UNet,
|
||||
DownBlocks,
|
||||
MiddleBlock,
|
||||
ResidualBlock,
|
||||
TimestepEncoder,
|
||||
)
|
||||
from refiners.adapters.adapter import Adapter
|
||||
from refiners.adapters.range_adapter import RangeAdapter2d
|
||||
from typing import cast, Iterable
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
@ -69,10 +71,12 @@ class ConditionEncoder(Chain):
|
|||
)
|
||||
|
||||
|
||||
class SD1Controlnet(Passthrough):
|
||||
structural_attrs = ["name", "scale"]
|
||||
class Controlnet(Passthrough):
|
||||
structural_attrs = ["scale"]
|
||||
|
||||
def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
def __init__(
|
||||
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
|
||||
) -> None:
|
||||
"""Controlnet is a Half-UNet that collects residuals from the UNet and uses them to condition the UNet.
|
||||
|
||||
Input is a `batch 3 width height` tensor, output is a `batch 1280 width//8 height//8` tensor with residuals
|
||||
|
@ -80,8 +84,7 @@ class SD1Controlnet(Passthrough):
|
|||
|
||||
It has to use the same context as the UNet: `unet` and `sampling`.
|
||||
"""
|
||||
self.name = name
|
||||
self.scale: float = 1.0
|
||||
self.scale = scale
|
||||
super().__init__(
|
||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||
Lambda(lambda x: x.narrow(dim=1, start=0, length=4)), # support inpainting
|
||||
|
@ -102,15 +105,14 @@ class SD1Controlnet(Passthrough):
|
|||
)
|
||||
for residual_block in self.layers(ResidualBlock):
|
||||
chain = residual_block.Chain
|
||||
range_adapter = RangeAdapter2d(
|
||||
RangeAdapter2d(
|
||||
target=chain.Conv2d_1,
|
||||
channels=residual_block.out_channels,
|
||||
embedding_dim=1280,
|
||||
context_key=f"timestep_embedding_{self.name}",
|
||||
context_key=f"timestep_embedding_{name}",
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
range_adapter.inject(chain)
|
||||
).inject(chain)
|
||||
for n, block in enumerate(cast(Iterable[Chain], self.DownBlocks)):
|
||||
assert hasattr(block[0], "out_channels"), (
|
||||
"The first block of every subchain in DownBlocks is expected to respond to `out_channels`,"
|
||||
|
@ -132,14 +134,6 @@ class SD1Controlnet(Passthrough):
|
|||
)
|
||||
)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {
|
||||
"unet": {"residuals": [0.0] * 13},
|
||||
"sampling": {"shapes": []},
|
||||
"controlnet": {f"condition_{self.name}": None},
|
||||
"range_adapter": {f"timestep_embedding_{self.name}": None},
|
||||
}
|
||||
|
||||
def _store_nth_residual(self, n: int):
|
||||
def _store_residual(x: Tensor):
|
||||
residuals = self.use_context("unet")["residuals"]
|
||||
|
@ -148,8 +142,39 @@ class SD1Controlnet(Passthrough):
|
|||
|
||||
return _store_residual
|
||||
|
||||
|
||||
class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
||||
def __init__(
|
||||
self, target: SD1UNet, name: str, scale: float = 1.0, weights: dict[str, Tensor] | None = None
|
||||
) -> None:
|
||||
self.name = name
|
||||
|
||||
controlnet = Controlnet(name=name, scale=scale, device=target.device, dtype=target.dtype)
|
||||
if weights is not None:
|
||||
controlnet.load_state_dict(weights)
|
||||
self._controlnet: list[Controlnet] = [controlnet] # not registered by PyTorch
|
||||
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
|
||||
controlnet = self._controlnet[0]
|
||||
assert controlnet not in self.target, f"{controlnet} is already injected"
|
||||
self.target.insert(0, controlnet)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
self.target.remove(self._controlnet[0])
|
||||
super().eject()
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"controlnet": {f"condition_{self.name}": None}}
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
self._controlnet[0].scale = scale
|
||||
|
||||
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||
self.set_context("controlnet", {f"condition_{self.name}": condition})
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
self.scale = scale
|
||||
def structural_copy(self: "SD1ControlnetAdapter") -> "SD1ControlnetAdapter":
|
||||
raise RuntimeError("Controlnet cannot be copied, eject it first.")
|
||||
|
|
|
@ -278,15 +278,14 @@ class SD1UNet(fl.Chain):
|
|||
)
|
||||
for residual_block in self.layers(ResidualBlock):
|
||||
chain = residual_block.Chain
|
||||
range_adapter = RangeAdapter2d(
|
||||
RangeAdapter2d(
|
||||
target=chain.Conv2d_1,
|
||||
channels=residual_block.out_channels,
|
||||
embedding_dim=1280,
|
||||
context_key="timestep_embedding",
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
range_adapter.inject(chain)
|
||||
).inject(chain)
|
||||
for n, block in enumerate(cast(Iterable[fl.Chain], self.DownBlocks)):
|
||||
block.append(ResidualAccumulator(n))
|
||||
for n, block in enumerate(cast(Iterable[fl.Chain], self.UpBlocks)):
|
||||
|
|
|
@ -70,12 +70,11 @@ class DoubleTextEncoder(fl.Chain):
|
|||
) -> None:
|
||||
text_encoder_l = text_encoder_l or CLIPTextEncoderL(device=device, dtype=dtype)
|
||||
text_encoder_g = text_encoder_g or CLIPTextEncoderG(device=device, dtype=dtype)
|
||||
text_encoder_with_pooling = TextEncoderWithPooling(target=text_encoder_g, projection=projection)
|
||||
super().__init__(
|
||||
fl.Parallel(text_encoder_l[:-2], text_encoder_g),
|
||||
fl.Lambda(func=self.concatenate_embeddings),
|
||||
)
|
||||
text_encoder_with_pooling.inject(parent=self.Parallel)
|
||||
TextEncoderWithPooling(target=text_encoder_g, projection=projection).inject(parent=self.Parallel)
|
||||
|
||||
def __call__(self, text: str) -> tuple[Float[Tensor, "1 77 2048"], Float[Tensor, "1 1280"]]:
|
||||
return super().__call__(text)
|
||||
|
|
|
@ -261,15 +261,14 @@ class SDXLUNet(fl.Chain):
|
|||
)
|
||||
for residual_block in self.layers(ResidualBlock):
|
||||
chain = residual_block.Chain
|
||||
range_adapter = RangeAdapter2d(
|
||||
RangeAdapter2d(
|
||||
target=chain.Conv2d_1,
|
||||
channels=residual_block.out_channels,
|
||||
embedding_dim=1280,
|
||||
context_key="timestep_embedding",
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
range_adapter.inject(chain)
|
||||
).inject(chain)
|
||||
for n, block in enumerate(iterable=cast(list[fl.Chain], self.DownBlocks)):
|
||||
block.append(module=ResidualAccumulator(n=n))
|
||||
for n, block in enumerate(iterable=cast(list[fl.Chain], self.UpBlocks)):
|
||||
|
|
|
@ -164,8 +164,7 @@ def apply_dropout(module: fl.Chain, probability: float = 0.5) -> None:
|
|||
assert not (
|
||||
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
|
||||
), f"{linear} already has a dropout layer"
|
||||
adapter = DropoutAdapter(target=linear, probability=probability)
|
||||
adapter.inject(parent)
|
||||
DropoutAdapter(target=linear, probability=probability).inject(parent)
|
||||
|
||||
|
||||
def apply_gyro_dropout(
|
||||
|
@ -181,14 +180,13 @@ def apply_gyro_dropout(
|
|||
assert not (
|
||||
isinstance(parent, Dropout) or isinstance(parent, GyroDropout)
|
||||
), f"{linear} already has a dropout layer"
|
||||
adapter = GyroDropoutAdapter(
|
||||
GyroDropoutAdapter(
|
||||
target=linear,
|
||||
probability=probability,
|
||||
total_subnetworks=total_subnetworks,
|
||||
concurrent_subnetworks=concurrent_subnetworks,
|
||||
iters_per_epoch=iters_per_epoch,
|
||||
)
|
||||
adapter.inject(parent)
|
||||
).inject(parent)
|
||||
|
||||
|
||||
ConfigType = TypeVar("ConfigType", bound="BaseConfig")
|
||||
|
|
|
@ -24,9 +24,8 @@ def test_weighted_module_adapter_insertion(chain: Chain):
|
|||
parent = chain.Chain
|
||||
adaptee = parent.Linear
|
||||
|
||||
adapter = DummyLinearAdapter(adaptee)
|
||||
adapter = DummyLinearAdapter(adaptee).inject(parent)
|
||||
|
||||
adapter.inject(parent)
|
||||
assert adapter.parent == parent
|
||||
assert adapter in iter(parent)
|
||||
assert adaptee not in iter(parent)
|
||||
|
@ -61,8 +60,7 @@ def test_weighted_module_adapter_structural_copy(chain: Chain):
|
|||
parent = chain.Chain
|
||||
adaptee = parent.Linear
|
||||
|
||||
adapter = DummyLinearAdapter(adaptee)
|
||||
adapter.inject(parent)
|
||||
DummyLinearAdapter(adaptee).inject(parent)
|
||||
|
||||
clone = chain.structural_copy()
|
||||
cloned_adapter = clone.Chain.DummyLinearAdapter
|
||||
|
@ -72,8 +70,7 @@ def test_weighted_module_adapter_structural_copy(chain: Chain):
|
|||
|
||||
def test_chain_adapter_structural_copy(chain: Chain):
|
||||
# Chain adapters cannot be copied by default.
|
||||
adapter = DummyChainAdapter(chain.Chain)
|
||||
adapter.inject()
|
||||
adapter = DummyChainAdapter(chain.Chain).inject()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
chain.structural_copy()
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from refiners.adapters.lora import Lora, LoraAdapter
|
||||
from refiners.adapters.lora import Lora, SingleLoraAdapter, LoraAdapter
|
||||
from torch import randn, allclose
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
def test_lora() -> None:
|
||||
def test_single_lora_adapter() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
|
@ -14,8 +14,7 @@ def test_lora() -> None:
|
|||
x = randn(1, 1)
|
||||
y = chain(x)
|
||||
|
||||
lora_adapter = LoraAdapter(chain.Chain.Linear_1)
|
||||
lora_adapter.inject(chain.Chain)
|
||||
lora_adapter = SingleLoraAdapter(chain.Chain.Linear_1).inject(chain.Chain)
|
||||
|
||||
assert isinstance(lora_adapter[1], Lora)
|
||||
assert allclose(input=chain(x), other=y)
|
||||
|
@ -26,4 +25,18 @@ def test_lora() -> None:
|
|||
assert len(chain) == 2
|
||||
|
||||
lora_adapter.inject(chain.Chain)
|
||||
assert isinstance(chain.Chain[0], LoraAdapter)
|
||||
assert isinstance(chain.Chain[0], SingleLoraAdapter)
|
||||
|
||||
|
||||
def test_lora_adapter() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
),
|
||||
fl.Linear(in_features=1, out_features=2),
|
||||
)
|
||||
|
||||
LoraAdapter[fl.Chain](chain, sub_targets=chain.walk(fl.Linear), rank=1, scale=1.0).inject()
|
||||
|
||||
assert len(list(chain.layers(Lora))) == 3
|
||||
|
|
|
@ -15,8 +15,7 @@ def test_range_encoder_dtype_after_adaptation(test_device: torch.device): # FG-
|
|||
chain = Chain(RangeEncoder(320, 1280, device=test_device, dtype=dtype))
|
||||
|
||||
adaptee = chain.RangeEncoder.Linear_1
|
||||
adapter = DummyLinearAdapter(adaptee)
|
||||
adapter.inject(chain.RangeEncoder)
|
||||
adapter = DummyLinearAdapter(adaptee).inject(chain.RangeEncoder)
|
||||
|
||||
assert adapter.parent == chain.RangeEncoder
|
||||
|
||||
|
|
|
@ -12,9 +12,9 @@ from refiners.foundationals.latent_diffusion import (
|
|||
StableDiffusion_1,
|
||||
StableDiffusion_1_Inpainting,
|
||||
SD1UNet,
|
||||
SD1Controlnet,
|
||||
SD1ControlnetAdapter,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
|
@ -57,6 +57,11 @@ def expected_image_std_inpainting(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "expected_std_inpainting.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_controlnet_stack(ref_path: Path) -> Image.Image:
|
||||
return Image.open(ref_path / "expected_controlnet_stack.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=["canny", "depth", "lineart", "normals", "sam"])
|
||||
def controlnet_data(
|
||||
ref_path: Path, test_weights_path: Path, request: pytest.FixtureRequest
|
||||
|
@ -85,6 +90,15 @@ def controlnet_data_canny(ref_path: Path, test_weights_path: Path) -> tuple[str,
|
|||
return cn_name, condition_image, expected_image, weights_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def controlnet_data_depth(ref_path: Path, test_weights_path: Path) -> tuple[str, Image.Image, Image.Image, Path]:
|
||||
cn_name = "depth"
|
||||
condition_image = Image.open(ref_path / f"cutecat_guide_{cn_name}.png").convert("RGB")
|
||||
expected_image = Image.open(ref_path / f"expected_controlnet_{cn_name}.png").convert("RGB")
|
||||
weights_path = test_weights_path / "controlnet" / "lllyasviel_control_v11f1p_sd15_depth.safetensors"
|
||||
return cn_name, condition_image, expected_image, weights_path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lora_data_pokemon(ref_path: Path, test_weights_path: Path) -> tuple[Image.Image, Path]:
|
||||
expected_image = Image.open(ref_path / "expected_lora_pokemon.png").convert("RGB")
|
||||
|
@ -453,11 +467,9 @@ def test_diffusion_controlnet(
|
|||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||
controlnet = SD1Controlnet(name=cn_name, device=test_device)
|
||||
controlnet.load_state_dict(controlnet_state_dict)
|
||||
controlnet.set_scale(0.5)
|
||||
sd15.unet.insert(0, controlnet)
|
||||
controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
|
||||
).inject()
|
||||
|
||||
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
|
||||
|
||||
|
@ -502,11 +514,9 @@ def test_diffusion_controlnet_structural_copy(
|
|||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||
controlnet = SD1Controlnet(name=cn_name, device=test_device)
|
||||
controlnet.load_state_dict(controlnet_state_dict)
|
||||
controlnet.set_scale(0.5)
|
||||
sd15.unet.insert(0, controlnet)
|
||||
controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
|
||||
).inject()
|
||||
|
||||
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device)
|
||||
|
||||
|
@ -550,11 +560,9 @@ def test_diffusion_controlnet_float16(
|
|||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||
controlnet = SD1Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
|
||||
controlnet.load_state_dict(controlnet_state_dict)
|
||||
controlnet.set_scale(0.5)
|
||||
sd15.unet.insert(0, controlnet)
|
||||
controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name=cn_name, scale=0.5, weights=load_from_safetensors(cn_weights_path)
|
||||
).inject()
|
||||
|
||||
cn_condition = image_to_tensor(condition_image.convert("RGB"), device=test_device, dtype=torch.float16)
|
||||
|
||||
|
@ -575,6 +583,64 @@ def test_diffusion_controlnet_float16(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_controlnet_stack(
|
||||
sd15_std: StableDiffusion_1,
|
||||
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
|
||||
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
||||
expected_image_controlnet_stack: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_std
|
||||
n_steps = 30
|
||||
|
||||
_, depth_condition_image, _, depth_cn_weights_path = controlnet_data_depth
|
||||
_, canny_condition_image, _, canny_cn_weights_path = controlnet_data_canny
|
||||
|
||||
if not canny_cn_weights_path.is_file():
|
||||
warn(f"could not find weights at {canny_cn_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
if not depth_cn_weights_path.is_file():
|
||||
warn(f"could not find weights at {depth_cn_weights_path}, skipping")
|
||||
pytest.skip(allow_module_level=True)
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
|
||||
with torch.no_grad():
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
depth_controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name="depth", scale=0.3, weights=load_from_safetensors(depth_cn_weights_path)
|
||||
).inject()
|
||||
canny_controlnet = SD1ControlnetAdapter(
|
||||
sd15.unet, name="canny", scale=0.7, weights=load_from_safetensors(canny_cn_weights_path)
|
||||
).inject()
|
||||
|
||||
depth_cn_condition = image_to_tensor(depth_condition_image.convert("RGB"), device=test_device)
|
||||
canny_cn_condition = image_to_tensor(canny_condition_image.convert("RGB"), device=test_device)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
||||
with torch.no_grad():
|
||||
for step in sd15.steps:
|
||||
depth_controlnet.set_controlnet_condition(depth_cn_condition)
|
||||
canny_controlnet.set_controlnet_condition(canny_cn_condition)
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_diffusion_lora(
|
||||
sd15_std: StableDiffusion_1,
|
||||
|
@ -597,8 +663,7 @@ def test_diffusion_lora(
|
|||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
|
||||
lora_weights = LoraWeights(lora_weights_path, device=test_device)
|
||||
lora_weights.patch(sd15, scale=1.0)
|
||||
SD1LoraAdapter.from_safetensors(target=sd15, checkpoint_path=lora_weights_path, scale=1.0).inject()
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
@ -629,8 +694,7 @@ def test_diffusion_refonly(
|
|||
with torch.no_grad():
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sai = SelfAttentionInjection(sd15.unet)
|
||||
sai.inject()
|
||||
sai = SelfAttentionInjection(sd15.unet).inject()
|
||||
|
||||
guide = sd15.lda.encode_image(condition_image_refonly)
|
||||
guide = torch.cat((guide, guide))
|
||||
|
@ -671,29 +735,26 @@ def test_diffusion_inpainting_refonly(
|
|||
with torch.no_grad():
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sai = SelfAttentionInjection(sd15.unet)
|
||||
sai.inject()
|
||||
sai = SelfAttentionInjection(sd15.unet).inject()
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
|
||||
|
||||
refonly_guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
|
||||
refonly_guide = torch.cat((refonly_guide, refonly_guide))
|
||||
guide = sd15.lda.encode_image(scene_image_inpainting_refonly)
|
||||
guide = torch.cat((guide, guide))
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
||||
with torch.no_grad():
|
||||
for step in sd15.steps:
|
||||
refonly_noise = torch.randn_like(refonly_guide)
|
||||
refonly_noised_guide = sd15.scheduler.add_noise(refonly_guide, refonly_noise, step)
|
||||
noise = torch.randn_like(guide)
|
||||
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
|
||||
# See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
|
||||
# inpaint variation models")
|
||||
refonly_noised_guide = torch.cat(
|
||||
[refonly_noised_guide, torch.zeros_like(refonly_noised_guide)[:, 0:1, :, :], refonly_guide], dim=1
|
||||
)
|
||||
noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
|
||||
|
||||
sai.set_controlnet_condition(refonly_noised_guide)
|
||||
sai.set_controlnet_condition(noised_guide)
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_controlnet_stack.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_controlnet_stack.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 385 KiB |
85
tests/foundationals/latent_diffusion/test_controlnet.py
Normal file
85
tests/foundationals/latent_diffusion/test_controlnet.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
from typing import Iterator
|
||||
|
||||
import torch
|
||||
import pytest
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[True, False])
|
||||
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
|
||||
with_parent: bool = request.param
|
||||
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||
if with_parent:
|
||||
fl.Chain(unet)
|
||||
yield unet
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_single_controlnet(unet: SD1UNet) -> None:
|
||||
original_parent = unet.parent
|
||||
cn = SD1ControlnetAdapter(unet, name="cn")
|
||||
|
||||
assert unet.parent == original_parent
|
||||
assert len(list(unet.walk(Controlnet))) == 0
|
||||
|
||||
with pytest.raises(ValueError) as exc:
|
||||
cn.eject()
|
||||
assert "not in" in str(exc.value)
|
||||
|
||||
cn.inject()
|
||||
assert unet.parent == cn
|
||||
assert len(list(unet.walk(Controlnet))) == 1
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
cn.inject()
|
||||
assert "already injected" in str(exc.value)
|
||||
|
||||
cn.eject()
|
||||
assert unet.parent == original_parent
|
||||
assert len(list(unet.walk(Controlnet))) == 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
|
||||
original_parent = unet.parent
|
||||
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
|
||||
cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
|
||||
|
||||
assert unet.parent == cn2
|
||||
assert unet in cn2
|
||||
assert unet not in cn1
|
||||
assert cn2.parent == cn1
|
||||
assert cn2 in cn1
|
||||
assert cn1.parent == original_parent
|
||||
assert len(list(unet.walk(Controlnet))) == 2
|
||||
assert cn1.target == unet
|
||||
assert cn1.lookup_actual_target() == cn2
|
||||
|
||||
cn2.eject()
|
||||
assert unet.parent == cn1
|
||||
assert unet in cn2
|
||||
assert cn2 not in cn1
|
||||
assert unet in cn1
|
||||
assert len(list(unet.walk(Controlnet))) == 1
|
||||
|
||||
cn1.eject()
|
||||
assert unet.parent == original_parent
|
||||
assert len(list(unet.walk(Controlnet))) == 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
|
||||
original_parent = unet.parent
|
||||
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
|
||||
cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
|
||||
|
||||
cn1.eject()
|
||||
assert cn2.parent == original_parent
|
||||
assert unet.parent == cn2
|
||||
|
||||
cn2.eject()
|
||||
assert unet.parent == original_parent
|
||||
assert len(list(unet.walk(Controlnet))) == 0
|
|
@ -1,16 +0,0 @@
|
|||
from refiners.adapters.lora import Lora
|
||||
from refiners.foundationals.latent_diffusion.lora import apply_loras_to_target, LoraTarget
|
||||
import refiners.fluxion.layers as fl
|
||||
|
||||
|
||||
def test_lora_target_self() -> None:
|
||||
chain = fl.Chain(
|
||||
fl.Chain(
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
fl.Linear(in_features=1, out_features=1),
|
||||
),
|
||||
fl.Linear(in_features=1, out_features=2),
|
||||
)
|
||||
apply_loras_to_target(chain, LoraTarget.Self, 1, 1.0)
|
||||
|
||||
assert len(list(chain.layers(Lora))) == 3
|
|
@ -0,0 +1,48 @@
|
|||
import torch
|
||||
import pytest
|
||||
|
||||
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.self_attention_injection import (
|
||||
SelfAttentionInjection,
|
||||
SaveLayerNormAdapter,
|
||||
ReferenceOnlyControlAdapter,
|
||||
SelfAttentionInjectionPassthrough,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_sai_inject_eject() -> None:
|
||||
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||
sai = SelfAttentionInjection(unet)
|
||||
|
||||
nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
|
||||
assert nb_cross_attention_blocks > 0
|
||||
|
||||
assert unet.parent is None
|
||||
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
|
||||
assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
|
||||
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
sai.eject()
|
||||
assert "not the first element" in str(exc.value)
|
||||
|
||||
sai.inject()
|
||||
|
||||
assert unet.parent == sai
|
||||
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
|
||||
assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
|
||||
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == nb_cross_attention_blocks
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
sai.inject()
|
||||
assert "already injected" in str(exc.value)
|
||||
|
||||
sai.eject()
|
||||
|
||||
assert unet.parent is None
|
||||
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
|
||||
assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
|
||||
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
|
Loading…
Reference in a new issue