mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
27 lines
931 B
Python
27 lines
931 B
Python
import pytest
|
|
import torch
|
|
|
|
from refiners.fluxion.utils import no_grad
|
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_XL
|
|
|
|
|
|
@no_grad()
|
|
@pytest.mark.parametrize("k_sd", [StableDiffusion_1, StableDiffusion_XL])
|
|
def test_set_self_attention_guidance(
|
|
k_sd: type[StableDiffusion_1] | type[StableDiffusion_XL], test_device: torch.device
|
|
):
|
|
sd = k_sd(device=test_device, dtype=torch.float16)
|
|
|
|
assert sd._find_sag_adapter() is None # type: ignore
|
|
sd.set_self_attention_guidance(enable=True, scale=0.42)
|
|
adapter = sd._find_sag_adapter() # type: ignore
|
|
assert adapter is not None
|
|
assert adapter.scale == 0.42
|
|
|
|
sd.set_self_attention_guidance(enable=True, scale=0.75)
|
|
assert sd._find_sag_adapter() == adapter # type: ignore
|
|
assert adapter.scale == 0.75
|
|
|
|
sd.set_self_attention_guidance(enable=False)
|
|
assert sd._find_sag_adapter() is None # type: ignore
|