check no two controlnets have the same name

This commit is contained in:
Pierre Chapuis 2023-09-01 16:28:30 +02:00
parent bd59790e08
commit e91e31ebd2
2 changed files with 16 additions and 2 deletions

View file

@ -72,7 +72,7 @@ class ConditionEncoder(Chain):
class Controlnet(Passthrough): class Controlnet(Passthrough):
structural_attrs = ["scale"] structural_attrs = ["scale", "name"]
def __init__( def __init__(
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
@ -84,6 +84,7 @@ class Controlnet(Passthrough):
It has to use the same context as the UNet: `unet` and `sampling`. It has to use the same context as the UNet: `unet` and `sampling`.
""" """
self.name = name
self.scale = scale self.scale = scale
super().__init__( super().__init__(
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype), TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
@ -159,7 +160,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter": def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
controlnet = self._controlnet[0] controlnet = self._controlnet[0]
assert controlnet not in self.target, f"{controlnet} is already injected" target_controlnets = [x for x in self.target if isinstance(x, Controlnet)]
assert controlnet not in target_controlnets, f"{controlnet} is already injected"
for cn in target_controlnets:
assert cn.name != self.name, f"Controlnet named {self.name} is already injected"
self.target.insert(0, controlnet) self.target.insert(0, controlnet)
return super().inject(parent) return super().inject(parent)

View file

@ -83,3 +83,13 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
cn2.eject() cn2.eject()
assert unet.parent == original_parent assert unet.parent == original_parent
assert len(list(unet.walk(Controlnet))) == 0 assert len(list(unet.walk(Controlnet))) == 0
@torch.no_grad()
def test_two_controlnets_same_name(unet: SD1UNet) -> None:
SD1ControlnetAdapter(unet, name="cnx").inject()
cn2 = SD1ControlnetAdapter(unet, name="cnx")
with pytest.raises(AssertionError) as exc:
cn2.inject()
assert "Controlnet named cnx is already injected" in str(exc.value)