Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-21 21:58:47 +00:00)
rename SelfAttentionInjection to ReferenceOnlyControl and vice-versa
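In short: the per-SelfAttention adapter formerly named ReferenceOnlyControlAdapter becomes SelfAttentionInjectionAdapter, while the SD1UNet-level SelfAttentionInjection takes over the name ReferenceOnlyControlAdapter; the module self_attention_injection is renamed to reference_only_control, and the "self_attention_injection" context key becomes "reference_only_control".

For downstream code, the migration looks roughly like this (a minimal before/after sketch based on the test changes in this diff; sd15 stands for an already-constructed Stable Diffusion 1.5 pipeline and its setup is not part of this commit):

    # before this commit
    from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
    sai = SelfAttentionInjection(sd15.unet).inject()

    # after this commit
    from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
    sai = ReferenceOnlyControlAdapter(sd15.unet).inject()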
parent eba0c33001
commit 73813310d0
@@ -22,7 +22,7 @@ class SaveLayerNormAdapter(Chain, Adapter[SelfAttention]):
         super().__init__(SetContext(self.context, "norm"), target)
 
 
-class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
+class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
     def __init__(
         self,
         target: SelfAttention,
@@ -64,7 +64,7 @@ class SelfAttentionInjectionPassthrough(Passthrough):
 
         super().__init__(
             Lambda(self._copy_diffusion_context),
-            UseContext("self_attention_injection", "guide"),
+            UseContext("reference_only_control", "guide"),
             guide_unet,
             Lambda(self._restore_diffusion_context),
         )
@@ -91,14 +91,14 @@ class SelfAttentionInjectionPassthrough(Passthrough):
         return x
 
 
-class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
+class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
     # TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
 
     def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
         # the style_cfg is the weight of the guide in unconditionned diffusion.
         # This value is recommended to be 0.5 on the sdwebui repo.
 
-        self.sub_adapters: list[ReferenceOnlyControlAdapter] = []
+        self.sub_adapters: list[SelfAttentionInjectionAdapter] = []
         self._passthrough: list[SelfAttentionInjectionPassthrough] = [
             SelfAttentionInjectionPassthrough(target)
         ] # not registered by PyTorch
@@ -113,10 +113,10 @@ class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
             assert sa is not None and sa.parent is not None
 
             self.sub_adapters.append(
-                ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
+                SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
             )
 
-    def inject(self: "SelfAttentionInjection", parent: Chain | None = None) -> "SelfAttentionInjection":
+    def inject(self: "ReferenceOnlyControlAdapter", parent: Chain | None = None) -> "ReferenceOnlyControlAdapter":
         passthrough = self._passthrough[0]
         assert passthrough not in self.target, f"{passthrough} is already injected"
         for adapter in self.sub_adapters:
@@ -133,7 +133,7 @@ class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
         super().eject()
 
     def set_controlnet_condition(self, condition: Tensor) -> None:
-        self.set_context("self_attention_injection", {"guide": condition})
+        self.set_context("reference_only_control", {"guide": condition})
 
-    def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
-        raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")
+    def structural_copy(self: "ReferenceOnlyControlAdapter") -> "ReferenceOnlyControlAdapter":
+        raise RuntimeError("ReferenceOnlyControlAdapter cannot be copied, eject it first.")
@@ -16,7 +16,7 @@ from refiners.foundationals.latent_diffusion import (
 )
 from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
-from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
+from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
 from refiners.foundationals.clip.concepts import ConceptExtender
 
 from tests.utils import ensure_similar_images
@@ -694,7 +694,7 @@ def test_diffusion_refonly(
     with torch.no_grad():
         clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
-    sai = SelfAttentionInjection(sd15.unet).inject()
+    sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
 
     guide = sd15.lda.encode_image(condition_image_refonly)
     guide = torch.cat((guide, guide))
@@ -735,7 +735,7 @@ def test_diffusion_inpainting_refonly(
     with torch.no_grad():
         clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
-    sai = SelfAttentionInjection(sd15.unet).inject()
+    sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
 
     sd15.set_num_inference_steps(n_steps)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
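For context, the reference-only flow these end-to-end tests drive looks roughly like this after the rename (a sketch assembled only from lines visible in this diff; the construction of sd15 and condition_image_refonly, and the denoising loop itself, are assumed and not shown here):

    with torch.no_grad():
        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

    sai = ReferenceOnlyControlAdapter(sd15.unet).inject()

    # encode the reference image and duplicate it for classifier-free guidance
    guide = sd15.lda.encode_image(condition_image_refonly)
    guide = torch.cat((guide, guide))
    sai.set_controlnet_condition(guide)  # stored under the renamed "reference_only_control" context

    # ... denoising loop and latent decoding omitted (not part of this diff)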
@@ -3,10 +3,10 @@ import pytest
 
 
 from refiners.foundationals.latent_diffusion import SD1UNet
-from refiners.foundationals.latent_diffusion.self_attention_injection import (
-    SelfAttentionInjection,
-    SaveLayerNormAdapter,
+from refiners.foundationals.latent_diffusion.reference_only_control import (
     ReferenceOnlyControlAdapter,
+    SaveLayerNormAdapter,
+    SelfAttentionInjectionAdapter,
     SelfAttentionInjectionPassthrough,
 )
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
@@ -15,7 +15,7 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
 @torch.no_grad()
 def test_sai_inject_eject() -> None:
     unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
-    sai = SelfAttentionInjection(unet)
+    sai = ReferenceOnlyControlAdapter(unet)
 
     nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
     assert nb_cross_attention_blocks > 0
@@ -23,7 +23,7 @@ def test_sai_inject_eject() -> None:
     assert unet.parent is None
     assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
     assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
-    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
+    assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0
 
     with pytest.raises(AssertionError) as exc:
         sai.eject()
@@ -34,7 +34,7 @@ def test_sai_inject_eject() -> None:
     assert unet.parent == sai
     assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
     assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
-    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == nb_cross_attention_blocks
+    assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks
 
     with pytest.raises(AssertionError) as exc:
         sai.inject()
@@ -45,4 +45,4 @@ def test_sai_inject_eject() -> None:
     assert unet.parent is None
     assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
     assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
-    assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
+    assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0
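The updated unit test keeps checking the same inject/eject contract under the new names: injecting adds one SelfAttentionInjectionPassthrough to the UNet plus one SaveLayerNormAdapter and one SelfAttentionInjectionAdapter per CrossAttentionBlock, ejecting removes them all, and calling eject() before inject() (or inject() twice) raises an AssertionError. A minimal usage sketch under those assumptions:

    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
    adapter = ReferenceOnlyControlAdapter(unet)
    adapter.inject()  # unet.parent is now the adapter; sub-adapters are wired into every CrossAttentionBlock
    # ... run the guided diffusion ...
    adapter.eject()   # restores the original UNet; a second eject() would fail the internal assertion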