From df843f52264154769a226365d8f52a9e27e61076 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 30 Jan 2024 15:22:34 +0100 Subject: [PATCH] test SAG setter --- .../latent_diffusion/__init__.py | 2 ++ .../test_self_attention_guidance.py | 26 +++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 tests/foundationals/latent_diffusion/test_self_attention_guidance.py diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py index ce47a46..35e1b9a 100644 --- a/src/refiners/foundationals/latent_diffusion/__init__.py +++ b/src/refiners/foundationals/latent_diffusion/__init__.py @@ -19,6 +19,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl import ( SDXLIPAdapter, SDXLT2IAdapter, SDXLUNet, + StableDiffusion_XL, ) __all__ = [ @@ -37,4 +38,5 @@ __all__ = [ "CLIPTextEncoderL", "LatentDiffusionAutoencoder", "SDFreeUAdapter", + "StableDiffusion_XL", ] diff --git a/tests/foundationals/latent_diffusion/test_self_attention_guidance.py b/tests/foundationals/latent_diffusion/test_self_attention_guidance.py new file mode 100644 index 0000000..ec920fb --- /dev/null +++ b/tests/foundationals/latent_diffusion/test_self_attention_guidance.py @@ -0,0 +1,26 @@ +import pytest +import torch + +from refiners.fluxion.utils import no_grad +from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_XL + + +@no_grad() +@pytest.mark.parametrize("k_sd", [StableDiffusion_1, StableDiffusion_XL]) +def test_set_self_attention_guidance( + k_sd: type[StableDiffusion_1] | type[StableDiffusion_XL], test_device: torch.device +): + sd = k_sd(device=test_device, dtype=torch.float16) + + assert sd._find_sag_adapter() is None # type: ignore + sd.set_self_attention_guidance(enable=True, scale=0.42) + adapter = sd._find_sag_adapter() # type: ignore + assert adapter is not None + assert adapter.scale == 0.42 + + sd.set_self_attention_guidance(enable=True, scale=0.75) + assert sd._find_sag_adapter() == adapter # type: ignore + assert adapter.scale == 0.75 + + sd.set_self_attention_guidance(enable=False) + assert sd._find_sag_adapter() is None # type: ignore