from typing import Iterator

import torch
import pytest

import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet


@pytest.fixture(scope="module", params=[True, False])
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
    """Yield one SD1UNet shared by every test in this module.

    The fixture is parametrized, so each test runs twice. One run uses a bare UNet.
    The other uses a UNet wrapped in a throwaway ``fl.Chain``, presumably so that
    ``unet.parent`` is non-None in that run.

    Because the instance is module-scoped, every test must leave it as it found it.
    In practice that means ejecting every adapter it injects.
    """
    with_parent: bool = request.param
    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
    if with_parent:
        # The Chain itself is discarded; only its side effect on `unet` matters here.
        fl.Chain(unet)
    yield unet


@torch.no_grad()
def test_single_controlnet(unet: SD1UNet) -> None:
    """Check the inject/eject lifecycle of a single adapter, including misuse errors."""
    original_parent = unet.parent
    cn = SD1ControlnetAdapter(unet, name="cn")

    # Constructing the adapter must not touch the UNet.
    assert unet.parent == original_parent
    assert len(list(unet.walk(Controlnet))) == 0

    # Ejecting an adapter that was never injected is rejected.
    with pytest.raises(ValueError) as exc:
        cn.eject()
    assert "not in" in str(exc.value)

    cn.inject()
    assert unet.parent == cn
    assert len(list(unet.walk(Controlnet))) == 1

    # A second inject is rejected with an AssertionError.
    # NOTE(review): this presumably comes from an `assert` inside inject(),
    # which `python -O` would strip — confirm.
    with pytest.raises(AssertionError) as exc:
        cn.inject()
    assert "already injected" in str(exc.value)

    cn.eject()
    assert unet.parent == original_parent
    assert len(list(unet.walk(Controlnet))) == 0


@torch.no_grad()
def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
    """Stack two adapters, then eject the inner one (cn2) before the outer one (cn1)."""
    original_parent = unet.parent
    cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
    cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()

    # The most recent adapter wraps the UNet directly; the earlier one wraps it.
    assert unet.parent == cn2
    assert unet in cn2
    assert unet not in cn1
    assert cn2.parent == cn1
    assert cn2 in cn1
    assert cn1.parent == original_parent
    assert len(list(unet.walk(Controlnet))) == 2
    # cn1 still targets the UNet, but the module it actually wraps is now cn2.
    assert cn1.target == unet
    assert cn1.lookup_actual_target() == cn2

    cn2.eject()
    assert unet.parent == cn1
    # The ejected adapter keeps the UNet as a child even though it is detached.
    assert unet in cn2
    assert cn2 not in cn1
    assert unet in cn1
    assert len(list(unet.walk(Controlnet))) == 1

    cn1.eject()
    assert unet.parent == original_parent
    assert len(list(unet.walk(Controlnet))) == 0


@torch.no_grad()
def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
    """Stack two adapters, then eject the outer one (cn1) before the inner one (cn2)."""
    original_parent = unet.parent
    cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
    cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
    assert len(list(unet.walk(Controlnet))) == 2

    # Ejecting the outer adapter splices cn2 into cn1's former slot...
    cn1.eject()
    assert cn2.parent == original_parent
    assert unet.parent == cn2
    # ...and must remove exactly one Controlnet from the UNet (mirrors the bottom-up test).
    assert len(list(unet.walk(Controlnet))) == 1

    cn2.eject()
    assert unet.parent == original_parent
    assert len(list(unet.walk(Controlnet))) == 0