https://github.com/finegrain-ai/refiners.git
fix legacy wording for refonly control
commit fc2390ad1c
parent 0e0c39b4b5

@@ -864,7 +864,7 @@ def test_diffusion_refonly(
     prompt = "Chicken"
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
-    sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
+    refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
 
     guide = sd15.lda.encode_image(condition_image_refonly)
     guide = torch.cat((guide, guide))
@@ -875,7 +875,7 @@ def test_diffusion_refonly(
     for step in sd15.steps:
         noise = torch.randn(2, 4, 64, 64, device=test_device)
         noised_guide = sd15.scheduler.add_noise(guide, noise, step)
-        sai.set_controlnet_condition(noised_guide)
+        refonly_adapter.set_controlnet_condition(noised_guide)
         x = sd15(
             x,
             step=step,
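
For context, here is how the renamed adapter drives the denoising loop, condensed from the two hunks above into one sketch. The pipeline setup, the reference image, and the trailing arguments of the sd15(...) call are not shown in this diff, so they are assumed or elided below. Re-noising the guide at every step keeps the reference latents at the same noise level as the latents being denoised, which is what the injected self-attention layers attend to.

    # Sketch only: assumes `sd15` is an initialized SD 1.5 pipeline, `x` the
    # current latents, and `condition_image_refonly` a loaded reference image.
    refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()

    guide = sd15.lda.encode_image(condition_image_refonly)
    guide = torch.cat((guide, guide))  # duplicate to match the batch of 2 below

    for step in sd15.steps:
        noise = torch.randn(2, 4, 64, 64, device=test_device)
        noised_guide = sd15.scheduler.add_noise(guide, noise, step)
        refonly_adapter.set_controlnet_condition(noised_guide)
        x = sd15(
            x,
            step=step,
            # remaining keyword arguments (e.g. the CLIP text embedding) are
            # truncated in this diff and left out here
        )
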
@@ -903,7 +903,7 @@ def test_diffusion_inpainting_refonly(
     prompt = ""  # unconditional
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
 
-    sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
+    refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
 
     sd15.set_num_inference_steps(n_steps)
     sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
@@ -921,7 +921,7 @@ def test_diffusion_inpainting_refonly(
         # inpaint variation models")
         noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
 
-        sai.set_controlnet_condition(noised_guide)
+        refonly_adapter.set_controlnet_condition(noised_guide)
         x = sd15(
             x,
             step=step,
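
The torch.cat above assembles the 9-channel input that inpainting UNets expect (matching SD1UNet(in_channels=9) in the test file below), assuming the standard Stable Diffusion inpainting layout: 4 noised-latent channels, 1 mask channel (all zeros here, i.e. nothing masked out of the reference), and 4 reference-latent channels. A self-contained shape check of that layout:

    import torch

    noised_guide = torch.randn(2, 4, 64, 64)  # noised reference latents
    guide = torch.randn(2, 4, 64, 64)         # clean reference latents
    mask = torch.zeros_like(noised_guide)[:, 0:1, :, :]  # 1-channel mask

    unet_input = torch.cat([noised_guide, mask, guide], dim=1)
    assert unet_input.shape == (2, 9, 64, 64)  # 4 + 1 + 4 channels
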
@@ -13,9 +13,9 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
 
 
 @torch.no_grad()
-def test_sai_inject_eject() -> None:
+def test_refonly_inject_eject() -> None:
     unet = SD1UNet(in_channels=9)
-    sai = ReferenceOnlyControlAdapter(unet)
+    adapter = ReferenceOnlyControlAdapter(unet)
 
     nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
     assert nb_cross_attention_blocks > 0
@@ -26,21 +26,21 @@ def test_sai_inject_eject() -> None:
     assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0
 
     with pytest.raises(AssertionError) as exc:
-        sai.eject()
+        adapter.eject()
     assert "not the first element" in str(exc.value)
 
-    sai.inject()
+    adapter.inject()
 
-    assert unet.parent == sai
+    assert unet.parent == adapter
     assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
     assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
     assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks
 
     with pytest.raises(AssertionError) as exc:
-        sai.inject()
+        adapter.inject()
     assert "already injected" in str(exc.value)
 
-    sai.eject()
+    adapter.eject()
 
     assert unet.parent is None
     assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0
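
The assertions above pin down the adapter lifecycle: eject() before inject() fails because the adapter is not yet the UNet's parent, inject() reparents the UNet and patches every cross-attention block, a second inject() fails, and eject() restores the original structure. Condensed to its core contract, using only names and calls that appear in the test:

    unet = SD1UNet(in_channels=9)
    adapter = ReferenceOnlyControlAdapter(unet)

    adapter.inject()
    assert unet.parent == adapter  # the adapter now wraps the UNet

    adapter.eject()
    assert unet.parent is None  # the UNet is free-standing again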