From fc2390ad1c0820d47a8c1eea5a8443cdab1062dc Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Thu, 14 Sep 2023 10:40:24 +0200 Subject: [PATCH] fix legacy wording for refonly control --- tests/e2e/test_diffusion.py | 8 ++++---- .../test_reference_only_control.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 1fb78f3..d6528c6 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -864,7 +864,7 @@ def test_diffusion_refonly( prompt = "Chicken" clip_text_embedding = sd15.compute_clip_text_embedding(prompt) - sai = ReferenceOnlyControlAdapter(sd15.unet).inject() + refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject() guide = sd15.lda.encode_image(condition_image_refonly) guide = torch.cat((guide, guide)) @@ -875,7 +875,7 @@ def test_diffusion_refonly( for step in sd15.steps: noise = torch.randn(2, 4, 64, 64, device=test_device) noised_guide = sd15.scheduler.add_noise(guide, noise, step) - sai.set_controlnet_condition(noised_guide) + refonly_adapter.set_controlnet_condition(noised_guide) x = sd15( x, step=step, @@ -903,7 +903,7 @@ def test_diffusion_inpainting_refonly( prompt = "" # unconditional clip_text_embedding = sd15.compute_clip_text_embedding(prompt) - sai = ReferenceOnlyControlAdapter(sd15.unet).inject() + refonly_adapter = ReferenceOnlyControlAdapter(sd15.unet).inject() sd15.set_num_inference_steps(n_steps) sd15.set_inpainting_conditions(target_image_inpainting_refonly, mask_image_inpainting_refonly) @@ -921,7 +921,7 @@ def test_diffusion_inpainting_refonly( # inpaint variation models") noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1) - sai.set_controlnet_condition(noised_guide) + refonly_adapter.set_controlnet_condition(noised_guide) x = sd15( x, step=step, diff --git a/tests/foundationals/latent_diffusion/test_reference_only_control.py b/tests/foundationals/latent_diffusion/test_reference_only_control.py index 580fdb0..95201bd 100644 --- a/tests/foundationals/latent_diffusion/test_reference_only_control.py +++ b/tests/foundationals/latent_diffusion/test_reference_only_control.py @@ -13,9 +13,9 @@ from refiners.foundationals.latent_diffusion.cross_attention import CrossAttenti @torch.no_grad() -def test_sai_inject_eject() -> None: +def test_refonly_inject_eject() -> None: unet = SD1UNet(in_channels=9) - sai = ReferenceOnlyControlAdapter(unet) + adapter = ReferenceOnlyControlAdapter(unet) nb_cross_attention_blocks = len(list(unet.walk(CrossAttentionBlock))) assert nb_cross_attention_blocks > 0 @@ -26,21 +26,21 @@ def test_sai_inject_eject() -> None: assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == 0 with pytest.raises(AssertionError) as exc: - sai.eject() + adapter.eject() assert "not the first element" in str(exc.value) - sai.inject() + adapter.inject() - assert unet.parent == sai + assert unet.parent == adapter assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 1 assert len(list(unet.walk(SaveLayerNormAdapter))) == nb_cross_attention_blocks assert len(list(unet.walk(SelfAttentionInjectionAdapter))) == nb_cross_attention_blocks with pytest.raises(AssertionError) as exc: - sai.inject() + adapter.inject() assert "already injected" in str(exc.value) - sai.eject() + adapter.eject() assert unet.parent is None assert len(list(unet.walk(SelfAttentionInjectionPassthrough))) == 0