mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
test SAG setter
This commit is contained in:
parent
f4ed7254fa
commit
df843f5226
|
@ -19,6 +19,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
|||
SDXLIPAdapter,
|
||||
SDXLT2IAdapter,
|
||||
SDXLUNet,
|
||||
StableDiffusion_XL,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
@ -37,4 +38,5 @@ __all__ = [
|
|||
"CLIPTextEncoderL",
|
||||
"LatentDiffusionAutoencoder",
|
||||
"SDFreeUAdapter",
|
||||
"StableDiffusion_XL",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from refiners.fluxion.utils import no_grad
|
||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_XL
|
||||
|
||||
|
||||
@no_grad()
|
||||
@pytest.mark.parametrize("k_sd", [StableDiffusion_1, StableDiffusion_XL])
|
||||
def test_set_self_attention_guidance(
|
||||
k_sd: type[StableDiffusion_1] | type[StableDiffusion_XL], test_device: torch.device
|
||||
):
|
||||
sd = k_sd(device=test_device, dtype=torch.float16)
|
||||
|
||||
assert sd._find_sag_adapter() is None # type: ignore
|
||||
sd.set_self_attention_guidance(enable=True, scale=0.42)
|
||||
adapter = sd._find_sag_adapter() # type: ignore
|
||||
assert adapter is not None
|
||||
assert adapter.scale == 0.42
|
||||
|
||||
sd.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
assert sd._find_sag_adapter() == adapter # type: ignore
|
||||
assert adapter.scale == 0.75
|
||||
|
||||
sd.set_self_attention_guidance(enable=False)
|
||||
assert sd._find_sag_adapter() is None # type: ignore
|
Loading…
Reference in a new issue