fix legacy wording for refonly control

This commit is contained in:
Pierre Chapuis 2023-09-14 10:40:24 +02:00
parent 0e0c39b4b5
commit fc2390ad1c
2 changed files with 11 additions and 11 deletions

View file

@ -864,7 +864,7 @@ def test_diffusion_refonly(
prompt = "Chicken" prompt = "Chicken"
clip_text_embedding = sd15.compute_clip_text_embedding(prompt) clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sai = ReferenceOnlyControlAdapter(sd15.unet).inject() refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
guide = sd15.lda.encode_image(condition_image_refonly) guide = sd15.lda.encode_image(condition_image_refonly)
guide = torch.cat((guide, guide)) guide = torch.cat((guide, guide))
@ -875,7 +875,7 @@ def test_diffusion_refonly(
for step in sd15.steps: for step in sd15.steps:
noise = torch.randn(2, 4, 64, 64, device=test_device) noise = torch.randn(2, 4, 64, 64, device=test_device)
noised_guide = sd15.scheduler.add_noise(guide, noise, step) noised_guide = sd15.scheduler.add_noise(guide, noise, step)
sai.set_controlnet_condition(noised_guide) refonly_adapter.set_controlnet_condition(noised_guide)
x = sd15( x = sd15(
x, x,
step=step, step=step,
@ -903,7 +903,7 @@ def test_diffusion_inpainting_refonly(
prompt = "" # unconditional prompt = "" # unconditional
clip_text_embedding = sd15.compute_clip_text_embedding(prompt) clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sai = ReferenceOnlyControlAdapter(sd15.unet).inject() refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject()
sd15.set_num_inference_steps(n_steps) sd15.set_num_inference_steps(n_steps)
sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly) sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly)
@ -921,7 +921,7 @@ def test_diffusion_inpainting_refonly(
# inpaint variation models") # inpaint variation models")
noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1) noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
sai.set_controlnet_condition(noised_guide) refonly_adapter.set_controlnet_condition(noised_guide)
x = sd15( x = sd15(
x, x,
step=step, step=step,

View file

@ -13,9 +13,9 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti
@torch.no_grad() @torch.no_grad()
def test_sai_inject_eject() -> None: def test_refonly_inject_eject() -> None:
unet = SD1UNet(in_channels=9) unet = SD1UNet(in_channels=9)
sai = ReferenceOnlyControlAdapter(unet) adapter = ReferenceOnlyControlAdapter(unet)
nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock))) nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock)))
assert nb_cross_attention_blocks > 0 assert nb_cross_attention_blocks > 0
@ -26,21 +26,21 @@ def test_sai_inject_eject() -> None:
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0 assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0
with pytest.raises(AssertionError) as exc: with pytest.raises(AssertionError) as exc:
sai.eject() adapter.eject()
assert "not the first element" in str(exc.value) assert "not the first element" in str(exc.value)
sai.inject() adapter.inject()
assert unet.parent == sai assert unet.parent == adapter
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1 assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1
assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks
assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks
with pytest.raises(AssertionError) as exc: with pytest.raises(AssertionError) as exc:
sai.inject() adapter.inject()
assert "already injected" in str(exc.value) assert "already injected" in str(exc.value)
sai.eject() adapter.eject()
assert unet.parent is None assert unet.parent is None
assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0 assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0