From e91e31ebd2035a79e78286db2ca5b295d83bb9d1 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 1 Sep 2023 16:28:30 +0200 Subject: [PATCH] check no two controlnets have the same name --- .../latent_diffusion/stable_diffusion_1/controlnet.py | 8 ++++++-- .../foundationals/latent_diffusion/test_controlnet.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py index 2439985..6a59f73 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/controlnet.py @@ -72,7 +72,7 @@ class ConditionEncoder(Chain): class Controlnet(Passthrough): - structural_attrs = ["scale"] + structural_attrs = ["scale", "name"] def __init__( self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None @@ -84,6 +84,7 @@ class Controlnet(Passthrough): It has to use the same context as the UNet: `unet` and `sampling`. """ + self.name = name self.scale = scale super().__init__( TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype), @@ -159,7 +160,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]): def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter": controlnet = self._controlnet[0] - assert controlnet not in self.target, f"{controlnet} is already injected" + target_controlnets = [x for x in self.target if isinstance(x, Controlnet)] + assert controlnet not in target_controlnets, f"{controlnet} is already injected" + for cn in target_controlnets: + assert cn.name != self.name, f"Controlnet named {self.name} is already injected" self.target.insert(0, controlnet) return super().inject(parent) diff --git a/tests/foundationals/latent_diffusion/test_controlnet.py b/tests/foundationals/latent_diffusion/test_controlnet.py index a9de8dc..58f4050 100644 --- a/tests/foundationals/latent_diffusion/test_controlnet.py +++ b/tests/foundationals/latent_diffusion/test_controlnet.py @@ -83,3 +83,13 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None: cn2.eject() assert unet.parent == original_parent assert len(list(unet.walk(Controlnet))) == 0 + + +@torch.no_grad() +def test_two_controlnets_same_name(unet: SD1UNet) -> None: + SD1ControlnetAdapter(unet, name="cnx").inject() + cn2 = SD1ControlnetAdapter(unet, name="cnx") + + with pytest.raises(AssertionError) as exc: + cn2.inject() + assert "Controlnet named cnx is already injected" in str(exc.value)