mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
0f476ea18b
This generalizes the Adapter abstraction to higher-level constructs such as high-level LoRA (targeting e.g. the SD UNet), ControlNet and Reference-Only Control. Some adapters now work by adapting child models with "sub-adapters" that they inject / eject when needed.
86 lines
2.4 KiB
Python
86 lines
2.4 KiB
Python
from typing import Iterator
|
|
|
|
import torch
|
|
import pytest
|
|
|
|
import refiners.fluxion.layers as fl
|
|
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
|
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
|
|
|
|
|
@pytest.fixture(scope="module", params=[True, False])
def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
    """Yield an SD1 UNet shared by every test in this module.

    The fixture is parametrized on whether the UNet is wrapped in a parent
    ``fl.Chain`` first (``True``) or left without one (``False``). Every test
    therefore runs in both situations and compares against ``unet.parent``.
    """
    model = SD1UNet(in_channels=9, clip_embedding_dim=768)
    if request.param:
        # The Chain is built only for its side effect on `model.parent`; it is
        # deliberately not bound to a name (presumably kept alive through the
        # UNet's parent reference -- confirm against fluxion's Chain).
        fl.Chain(model)
    yield model
|
|
|
|
|
|
@torch.no_grad()
def test_single_controlnet(unet: SD1UNet) -> None:
    """Walk one ControlNet adapter through its whole inject / eject lifecycle.

    Checks the following, in order:
    - Constructing the adapter leaves the UNet untouched.
    - Ejecting an adapter that was never injected raises ``ValueError``.
    - Injecting it re-parents the UNet and adds exactly one ``Controlnet``.
    - Injecting a second time raises ``AssertionError``.
    - Ejecting restores the original parent and removes the ``Controlnet``.
    """
    parent_before = unet.parent

    def controlnet_count() -> int:
        return len(list(unet.walk(Controlnet)))

    adapter = SD1ControlnetAdapter(unet, name="cn")

    # Building the adapter alone must not modify the UNet.
    assert unet.parent == parent_before
    assert controlnet_count() == 0

    # Nothing has been injected yet, so there is nothing to eject.
    with pytest.raises(ValueError) as not_injected:
        adapter.eject()
    assert "not in" in str(not_injected.value)

    adapter.inject()
    assert unet.parent == adapter
    assert controlnet_count() == 1

    # A second injection of the same adapter is refused.
    with pytest.raises(AssertionError) as already_injected:
        adapter.inject()
    assert "already injected" in str(already_injected.value)

    adapter.eject()
    assert unet.parent == parent_before
    assert controlnet_count() == 0
|
|
|
|
|
|
@torch.no_grad()
def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
    """Stack two ControlNet adapters, then eject the innermost one first.

    After both injections the parent chain is
    ``original parent -> cn1 -> cn2 -> unet``. Ejecting ``cn2`` and then
    ``cn1`` must unwind it one level at a time back to the original parent.
    """
    initial_parent = unet.parent
    outer = SD1ControlnetAdapter(unet, name="cn1").inject()
    inner = SD1ControlnetAdapter(unet, name="cn2").inject()

    def controlnet_count() -> int:
        return len(list(unet.walk(Controlnet)))

    # The second adapter is slotted between the first one and the UNet.
    assert unet.parent == inner
    assert unet in inner
    assert unet not in outer
    assert inner.parent == outer
    assert inner in outer
    assert outer.parent == initial_parent
    assert controlnet_count() == 2
    # cn1's declared target is still the UNet, but the module it actually
    # wraps is now cn2.
    assert outer.target == unet
    assert outer.lookup_actual_target() == inner

    inner.eject()
    assert unet.parent == outer
    # The ejected adapter still lists the UNet as a child.
    assert unet in inner
    assert inner not in outer
    assert unet in outer
    assert controlnet_count() == 1

    outer.eject()
    assert unet.parent == initial_parent
    assert controlnet_count() == 0
|
|
|
|
|
|
@torch.no_grad()
def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
    """Stack two ControlNet adapters, then eject the outermost one first.

    After both injections the parent chain is
    ``original parent -> cn1 -> cn2 -> unet``. Ejecting ``cn1`` must splice
    ``cn2`` directly under the original parent without disturbing ``cn2``'s
    own ``Controlnet``. Ejecting ``cn2`` afterwards must fully restore the UNet.

    The ``Controlnet`` counts mirror the bottom-up test. They catch a
    regression where ejecting the outer adapter removes the wrong
    ``Controlnet`` layer, or too many of them.
    """
    original_parent = unet.parent
    cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
    cn2 = SD1ControlnetAdapter(unet, name="cn2").inject()
    assert len(list(unet.walk(Controlnet))) == 2

    cn1.eject()
    # cn2 now hangs directly off the original parent and still wraps the UNet.
    assert cn2.parent == original_parent
    assert unet.parent == cn2
    assert unet in cn2
    # Only cn1's Controlnet was removed; cn2's is still in place.
    assert len(list(unet.walk(Controlnet))) == 1

    cn2.eject()
    assert unet.parent == original_parent
    assert len(list(unet.walk(Controlnet))) == 0
|