# Mirror of https://github.com/finegrain-ai/refiners.git
# Synced 2024-11-14 00:58:13 +00:00
# 86 lines, 2.4 KiB, Python
|
from typing import Iterator
|
||
|
|
||
|
import torch
|
||
|
import pytest
|
||
|
|
||
|
import refiners.fluxion.layers as fl
|
||
|
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
|
||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="module", params=[True, False])
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
    """Yield a module-scoped SD1UNet, parametrized over "with parent" / "without parent".

    The fixture runs once per parameter value. When the parameter is true the
    UNet is handed to an ``fl.Chain`` before being yielded; otherwise it is
    yielded as constructed.
    """
    model = SD1UNet(in_channels=9, clip_embedding_dim=768)
    if request.param:
        # The chain object is not kept: it is built only for its effect on `model`
        # (presumably registering itself as the UNet's parent -- confirm against fl.Chain).
        fl.Chain(model)
    yield model
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def test_single_controlnet(unet: SD1UNet) -> None:
    """Exercise one ControlNet adapter through construction, inject and eject."""

    def controlnet_count() -> int:
        # Number of Controlnet layers reachable by walking the UNet.
        return len(list(unet.walk(Controlnet)))

    initial_parent = unet.parent
    adapter = SD1ControlnetAdapter(unet, name="cn")

    # Merely constructing the adapter must leave the UNet untouched.
    assert unet.parent == initial_parent
    assert controlnet_count() == 0

    # Ejecting an adapter that was never injected is an error.
    with pytest.raises(ValueError) as eject_error:
        adapter.eject()
    assert "not in" in str(eject_error.value)

    # After injection the adapter is the UNet's parent and one Controlnet is visible.
    adapter.inject()
    assert unet.parent == adapter
    assert controlnet_count() == 1

    # Injecting a second time is refused.
    with pytest.raises(AssertionError) as inject_error:
        adapter.inject()
    assert "already injected" in str(inject_error.value)

    # Ejecting restores the initial state.
    adapter.eject()
    assert unet.parent == initial_parent
    assert controlnet_count() == 0
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
    """Stack two ControlNet adapters, then eject the innermost one first."""

    def controlnet_count() -> int:
        # Number of Controlnet layers reachable by walking the UNet.
        return len(list(unet.walk(Controlnet)))

    initial_parent = unet.parent
    first = SD1ControlnetAdapter(unet, name="cn1").inject()
    second = SD1ControlnetAdapter(unet, name="cn2").inject()

    # The adapter injected last sits directly above the UNet, nested inside the first one.
    assert unet.parent == second
    assert unet in second
    assert unet not in first
    assert second.parent == first
    assert second in first
    assert first.parent == initial_parent
    assert controlnet_count() == 2
    # The first adapter still targets the UNet, but its actual target resolves to the second adapter.
    assert first.target == unet
    assert first.lookup_actual_target() == second

    # Eject the inner adapter: the UNet moves back directly under the first adapter.
    second.eject()
    assert unet.parent == first
    assert unet in second
    assert second not in first
    assert unet in first
    assert controlnet_count() == 1

    # Eject the remaining adapter: the initial state is restored.
    first.eject()
    assert unet.parent == initial_parent
    assert controlnet_count() == 0
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
    """Stack two ControlNet adapters, then eject the outermost one first."""
    initial_parent = unet.parent
    first = SD1ControlnetAdapter(unet, name="cn1").inject()
    second = SD1ControlnetAdapter(unet, name="cn2").inject()

    # Removing the outer adapter promotes the inner one to the original parent,
    # while the UNet stays directly under the inner adapter.
    first.eject()
    assert second.parent == initial_parent
    assert unet.parent == second

    # Removing the inner adapter as well restores the initial state.
    second.eject()
    assert unet.parent == initial_parent
    remaining_controlnets = list(unet.walk(Controlnet))
    assert len(remaining_controlnets) == 0
|