mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
rename SelfAttentionInjection to ReferenceOnlyControl and vice-versa
This commit is contained in:
parent
eba0c33001
commit
73813310d0
|
@ -22,7 +22,7 @@ class SaveLayerNormAdapter(Chain, Adapter[SelfAttention]):
|
|||
super().__init__(SetContext(self.context, "norm"), target)
|
||||
|
||||
|
||||
class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
|
||||
class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
|
||||
def __init__(
|
||||
self,
|
||||
target: SelfAttention,
|
||||
|
@ -64,7 +64,7 @@ class SelfAttentionInjectionPassthrough(Passthrough):
|
|||
|
||||
super().__init__(
|
||||
Lambda(self._copy_diffusion_context),
|
||||
UseContext("self_attention_injection", "guide"),
|
||||
UseContext("reference_only_control", "guide"),
|
||||
guide_unet,
|
||||
Lambda(self._restore_diffusion_context),
|
||||
)
|
||||
|
@ -91,14 +91,14 @@ class SelfAttentionInjectionPassthrough(Passthrough):
|
|||
return x
|
||||
|
||||
|
||||
class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
|
||||
class ReferenceOnlyControlAdapter(Chain, Adapter[SD1UNet]):
|
||||
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
|
||||
|
||||
def __init__(self, target: SD1UNet, style_cfg: float = 0.5) -> None:
|
||||
# the style_cfg is the weight of the guide in unconditionned diffusion.
|
||||
# This value is recommended to be 0.5 on the sdwebui repo.
|
||||
|
||||
self.sub_adapters: list[ReferenceOnlyControlAdapter] = []
|
||||
self.sub_adapters: list[SelfAttentionInjectionAdapter] = []
|
||||
self._passthrough: list[SelfAttentionInjectionPassthrough] = [
|
||||
SelfAttentionInjectionPassthrough(target)
|
||||
] # not registered by PyTorch
|
||||
|
@ -113,10 +113,10 @@ class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
|
|||
assert sa is not None and sa.parent is not None
|
||||
|
||||
self.sub_adapters.append(
|
||||
ReferenceOnlyControlAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
|
||||
SelfAttentionInjectionAdapter(sa, context=f"self_attention_context_{i}", style_cfg=style_cfg)
|
||||
)
|
||||
|
||||
def inject(self: "SelfAttentionInjection", parent: Chain | None = None) -> "SelfAttentionInjection":
|
||||
def inject(self: "ReferenceOnlyControlAdapter", parent: Chain | None = None) -> "ReferenceOnlyControlAdapter":
|
||||
passthrough = self._passthrough[0]
|
||||
assert passthrough not in self.target, f"{passthrough} is already injected"
|
||||
for adapter in self.sub_adapters:
|
||||
|
@ -133,7 +133,7 @@ class SelfAttentionInjection(Chain, Adapter[SD1UNet]):
|
|||
super().eject()
|
||||
|
||||
def set_controlnet_condition(self, condition: Tensor) -> None:
|
||||
self.set_context("self_attention_injection", {"guide": condition})
|
||||
self.set_context("reference_only_control", {"guide": condition})
|
||||
|
||||
def structural_copy(self: "SelfAttentionInjection") -> "SelfAttentionInjection":
|
||||
raise RuntimeError("SelfAttentionInjection cannot be copied, eject it first.")
|
||||
def structural_copy(self: "ReferenceOnlyControlAdapter") -> "ReferenceOnlyControlAdapter":
|
||||
raise RuntimeError("ReferenceOnlyControlAdapter cannot be copied, eject it first.")
|
|
@ -16,7 +16,7 @@ from refiners.foundationals.latent_diffusion import (
|
|||
)
|
||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
|
||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
|
||||
from tests.utils import ensure_similar_images
|
||||
|
@ -694,7 +694,7 @@ def test_diffusion_refonly(
|
|||
with torch.no_grad():
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sai = SelfAttentionInjection(sd15.unet).inject()
|
||||
sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
|
||||
|
||||
guide = sd15.lda.encode_image(condition_image_refonly)
|
||||
guide = torch.cat((guide, guide))
|
||||
|
@ -735,7 +735,7 @@ def test_diffusion_inpainting_refonly(
|
|||
with torch.no_grad():
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sai = SelfAttentionInjection(sd15.unet).inject()
|
||||
sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
|
||||
|
|
|
@ -3,10 +3,10 @@ import pytest
|
|||
|
||||
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.self_attention_injection import (
|
||||
SelfAttentionInjection,
|
||||
SaveLayerNormAdapter,
|
||||
from refiners.foundationals.latent_diffusion.reference_only_control import (
|
||||
ReferenceOnlyControlAdapter,
|
||||
SaveLayerNormAdapter,
|
||||
SelfAttentionInjectionAdapter,
|
||||
SelfAttentionInjectionPassthrough,
|
||||
)
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
|
||||
|
@ -15,7 +15,7 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti
|
|||
@torch.no_grad()
|
||||
def test_sai_inject_eject() -> None:
|
||||
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||
sai = SelfAttentionInjection(unet)
|
||||
sai = ReferenceOnlyControlAdapter(unet)
|
||||
|
||||
nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
|
||||
assert nb_cross_attention_blocks > 0
|
||||
|
@ -23,7 +23,7 @@ def test_sai_inject_eject() -> None:
|
|||
assert unet.parent is None
|
||||
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
|
||||
assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
|
||||
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
|
||||
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
sai.eject()
|
||||
|
@ -34,7 +34,7 @@ def test_sai_inject_eject() -> None:
|
|||
assert unet.parent == sai
|
||||
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
|
||||
assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
|
||||
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == nb_cross_attention_blocks
|
||||
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
sai.inject()
|
||||
|
@ -45,4 +45,4 @@ def test_sai_inject_eject() -> None:
|
|||
assert unet.parent is None
|
||||
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
|
||||
assert len(list(unet.walk(SaveLayerNormAdapter))) == 0
|
||||
assert len(list(unet.walk(ReferenceOnlyControlAdapter))) == 0
|
||||
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0
|
Loading…
Reference in a new issue