mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
fix legacy wording for refonly control
This commit is contained in:
parent
0e0c39b4b5
commit
fc2390ad1c
|
@ -864,7 +864,7 @@ def test_diffusion_refonly(
|
|||
prompt = "Chicken"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
|
||||
refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
|
||||
|
||||
guide = sd15.lda.encode_image(condition_image_refonly)
|
||||
guide = torch.cat((guide, guide))
|
||||
|
@ -875,7 +875,7 @@ def test_diffusion_refonly(
|
|||
for step in sd15.steps:
|
||||
noise = torch.randn(2, 4, 64, 64, device=test_device)
|
||||
noised_guide = sd15.scheduler.add_noise(guide, noise, step)
|
||||
sai.set_controlnet_condition(noised_guide)
|
||||
refonly_adapter.set_controlnet_condition(noised_guide)
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
|
@ -903,7 +903,7 @@ def test_diffusion_inpainting_refonly(
|
|||
prompt = "" # unconditional
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||
|
||||
sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
|
||||
refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
|
||||
|
@ -921,7 +921,7 @@ def test_diffusion_inpainting_refonly(
|
|||
# inpaint variation models")
|
||||
noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
|
||||
|
||||
sai.set_controlnet_condition(noised_guide)
|
||||
refonly_adapter.set_controlnet_condition(noised_guide)
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
|
|
|
@ -13,9 +13,9 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_sai_inject_eject() -> None:
|
||||
def test_refonly_inject_eject() -> None:
|
||||
unet = SD1UNet(in_channels=9)
|
||||
sai = ReferenceOnlyControlAdapter(unet)
|
||||
adapter = ReferenceOnlyControlAdapter(unet)
|
||||
|
||||
nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
|
||||
assert nb_cross_attention_blocks > 0
|
||||
|
@ -26,21 +26,21 @@ def test_sai_inject_eject() -> None:
|
|||
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
sai.eject()
|
||||
adapter.eject()
|
||||
assert "not the first element" in str(exc.value)
|
||||
|
||||
sai.inject()
|
||||
adapter.inject()
|
||||
|
||||
assert unet.parent == sai
|
||||
assert unet.parent == adapter
|
||||
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
|
||||
assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
|
||||
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
sai.inject()
|
||||
adapter.inject()
|
||||
assert "already injected" in str(exc.value)
|
||||
|
||||
sai.eject()
|
||||
adapter.eject()
|
||||
|
||||
assert unet.parent is None
|
||||
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
|
||||
|
|
Loading…
Reference in a new issue