test IP adapter scale setter

This commit is contained in:
Pierre Chapuis 2024-01-30 15:07:37 +01:00 committed by Cédric Deltheil
parent 8341d3a74b
commit f4ed7254fa

View file

@@ -3,8 +3,10 @@ from typing import overload
import pytest
import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1IPAdapter, SD1UNet, SDXLIPAdapter, SDXLUNet
from refiners.foundationals.latent_diffusion.image_prompt import ImageCrossAttention
@overload
@@ -35,3 +37,23 @@ def test_inject_eject(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch
assert repr(unet) != initial_repr
adapter.eject()
assert repr(unet) == initial_repr
@no_grad()
@pytest.mark.parametrize("k_unet", [SD1UNet, SDXLUNet])
def test_scale(k_unet: type[SD1UNet] | type[SDXLUNet], test_device: torch.device):
    """Check that the adapter's ``scale`` setter propagates to every image cross-attention.

    Injects an IP adapter into a fresh UNet, verifies all ``fl.Multiply`` layers
    inside ``ImageCrossAttention`` chains start at scale 1.0, then sets
    ``adapter.scale`` and verifies the new value reached each of them.
    """
    unet = k_unet(in_channels=4, device=test_device, dtype=torch.float16)
    adapter = new_adapter(unet).inject()

    def predicate(m: fl.Module, p: fl.Chain) -> bool:
        # Match the Multiply layers that live directly under an ImageCrossAttention.
        return isinstance(p, ImageCrossAttention) and isinstance(m, fl.Multiply)

    multiplies = [m for m, _ in unet.walk(predicate)]
    # Guard against the predicate matching nothing: without this, the loops
    # below would be vacuously skipped and the test would silently pass.
    assert len(multiplies) > 0

    for m in multiplies:
        assert isinstance(m, fl.Multiply)  # narrow the type for the attribute access
        assert m.scale == 1.0

    adapter.scale = 0.42
    assert adapter.scale == 0.42
    for m in multiplies:
        assert m.scale == 0.42