test SAG setter

Pierre Chapuis 2024-01-30 15:22:34 +01:00 committed by Cédric Deltheil
parent f4ed7254fa
commit df843f5226
2 changed files with 28 additions and 0 deletions


@@ -19,6 +19,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
    SDXLIPAdapter,
    SDXLT2IAdapter,
    SDXLUNet,
    StableDiffusion_XL,
)

__all__ = [
@@ -37,4 +38,5 @@ __all__ = [
    "CLIPTextEncoderL",
    "LatentDiffusionAutoencoder",
    "SDFreeUAdapter",
    "StableDiffusion_XL",
]


@@ -0,0 +1,26 @@
import pytest
import torch

from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_XL


@no_grad()
@pytest.mark.parametrize("k_sd", [StableDiffusion_1, StableDiffusion_XL])
def test_set_self_attention_guidance(
    k_sd: type[StableDiffusion_1] | type[StableDiffusion_XL], test_device: torch.device
):
    sd = k_sd(device=test_device, dtype=torch.float16)
    assert sd._find_sag_adapter() is None  # type: ignore

    sd.set_self_attention_guidance(enable=True, scale=0.42)
    adapter = sd._find_sag_adapter()  # type: ignore
    assert adapter is not None
    assert adapter.scale == 0.42

    sd.set_self_attention_guidance(enable=True, scale=0.75)
    assert sd._find_sag_adapter() == adapter  # type: ignore
    assert adapter.scale == 0.75

    sd.set_self_attention_guidance(enable=False)
    assert sd._find_sag_adapter() is None  # type: ignore
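
For readers unfamiliar with the setter under test, here is a minimal usage sketch outside pytest, based only on the calls shown in this diff (constructing a model with default weights, then toggling self-attention guidance). The choice of CPU device and float16 dtype is an assumption for illustration.

```python
# Minimal sketch, assuming the Refiners API exercised by the test above.
import torch

from refiners.foundationals.latent_diffusion import StableDiffusion_1

sd = StableDiffusion_1(device=torch.device("cpu"), dtype=torch.float16)

# Enable self-attention guidance (SAG) with a given scale; calling the setter
# again with a new scale updates the existing adapter in place, as the test checks.
sd.set_self_attention_guidance(enable=True, scale=0.75)

# Disabling SAG removes the adapter again.
sd.set_self_attention_guidance(enable=False)
```

The test parametrizes the same checks over StableDiffusion_1 and StableDiffusion_XL, which is why the first file in this diff also re-exports StableDiffusion_XL from the latent_diffusion package.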