rename SelfAttentionInjection to ReferenceOnlyControl and vice-versa

This commit is contained in:
Pierre Chapuis 2023-09-01 12:02:47 +02:00
parent eba0c33001
commit 73813310d0
3 changed files with 19 additions and 19 deletions

View file

@ -22,7 +22,7 @@ class SaveLayerNormAdapter(Chain, Adapter[SelfAttention]):
super().__init__(SetContext(self.context, "norm"), target)
class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
def __init__(
self,
target: SelfAttention,
@ -64,7 +64,7 @@ class SelfAttentionInjectionPassthrough(Passthrough):
super().__init__(
Lambda(self._copy_diffusion_context),
UseContext("self_attention_injection", "guide"),
UseContext("reference_only_control", "guide"),
guide_unet,
Lambda(self._restore_diffusion_context),
)
@ -91,14 +91,14 @@ class SelfAttentionInjectionPassthrough(Passthrough):
return x
class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
# the style_cfg is the weight of the guide in unconditionned diffusion.
# This value is recommended to be 0.5 on the sdwebui repo.
self.sub_adapters: list[ReferenceOnlyControlAdapter] = []
self.sub_adapters: list[SelfAttentionInjectionAdapter] = []
self._passthrough: list[SelfAttentionInjectionPassthrough] = [
SelfAttentionInjectionPassthrough(target)
] # not registered by PyTorch
@ -113,10 +113,10 @@ class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
assert sa is not None and sa.parent is not None
self.sub_adapters.append(
ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
)
def inject(self: "SelfAttentionInjection", parent: Chain | None = None) -> "SelfAttentionInjection":
def inject(self: "ReferenceOnlyControlAdapter", parent: Chain | None = None) -> "ReferenceOnlyControlAdapter":
passthrough = self._passthrough[0]
assert passthrough not in self.target, f"{passthrough} is already injected"
for adapter in self.sub_adapters:
@ -133,7 +133,7 @@ class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
super().eject()
def set_controlnet_condition(self, condition: Tensor) -> None:
self.set_context("self_attention_injection", {"guide": condition})
self.set_context("reference_only_control", {"guide": condition})
def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")
def structural_copy(self: "ReferenceOnlyControlAdapter") -> "ReferenceOnlyControlAdapter":
raise RuntimeError("ReferenceOnlyControlAdapter cannot be copied, eject it first.")

View file

@ -16,7 +16,7 @@ from refiners.foundationals.latent_diffusion import (
)
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.clip.concepts import ConceptExtender
from tests.utils import ensure_similar_images
@ -694,7 +694,7 @@ def test_diffusion_refonly(
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sai = SelfAttentionInjection(sd15.unet).inject()
sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
guide = sd15.lda.encode_image(condition_image_refonly)
guide = torch.cat((guide, guide))
@ -735,7 +735,7 @@ def test_diffusion_inpainting_refonly(
with torch.no_grad():
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sai = SelfAttentionInjection(sd15.unet).inject()
sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
sd15.set_num_inference_steps(n_steps)
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)

View file

@ -3,10 +3,10 @@ import pytest
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.self_attention_injection import (
SelfAttentionInjection,
SaveLayerNormAdapter,
from refiners.foundationals.latent_diffusion.reference_only_control import (
ReferenceOnlyControlAdapter,
SaveLayerNormAdapter,
SelfAttentionInjectionAdapter,
SelfAttentionInjectionPassthrough,
)
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
@ -15,7 +15,7 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti
@torch.no_grad()
def test_sai_inject_eject() -> None:
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
sai = SelfAttentionInjection(unet)
sai = ReferenceOnlyControlAdapter(unet)
nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
assert nb_cross_attention_blocks > 0
@ -23,7 +23,7 @@ def test_sai_inject_eject() -> None:
assert unet.parent is None
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0
with pytest.raises(AssertionError) as exc:
sai.eject()
@ -34,7 +34,7 @@ def test_sai_inject_eject() -> None:
assert unet.parent == sai
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == nb_cross_attention_blocks
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks
with pytest.raises(AssertionError) as exc:
sai.inject()
@ -45,4 +45,4 @@ def test_sai_inject_eject() -> None:
assert unet.parent is None
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0