mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
check no two controlnets have the same name
This commit is contained in:
parent
bd59790e08
commit
e91e31ebd2
|
@ -72,7 +72,7 @@ class ConditionEncoder(Chain):
|
||||||
|
|
||||||
|
|
||||||
class Controlnet(Passthrough):
|
class Controlnet(Passthrough):
|
||||||
structural_attrs = ["scale"]
|
structural_attrs = ["scale", "name"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
|
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
|
||||||
|
@ -84,6 +84,7 @@ class Controlnet(Passthrough):
|
||||||
|
|
||||||
It has to use the same context as the UNet: `unet` and `sampling`.
|
It has to use the same context as the UNet: `unet` and `sampling`.
|
||||||
"""
|
"""
|
||||||
|
self.name = name
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
super().__init__(
|
super().__init__(
|
||||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||||
|
@ -159,7 +160,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
||||||
|
|
||||||
def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
|
def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
|
||||||
controlnet = self._controlnet[0]
|
controlnet = self._controlnet[0]
|
||||||
assert controlnet not in self.target, f"{controlnet} is already injected"
|
target_controlnets = [x for x in self.target if isinstance(x, Controlnet)]
|
||||||
|
assert controlnet not in target_controlnets, f"{controlnet} is already injected"
|
||||||
|
for cn in target_controlnets:
|
||||||
|
assert cn.name != self.name, f"Controlnet named {self.name} is already injected"
|
||||||
self.target.insert(0, controlnet)
|
self.target.insert(0, controlnet)
|
||||||
return super().inject(parent)
|
return super().inject(parent)
|
||||||
|
|
||||||
|
|
|
@ -83,3 +83,13 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
|
||||||
cn2.eject()
|
cn2.eject()
|
||||||
assert unet.parent == original_parent
|
assert unet.parent == original_parent
|
||||||
assert len(list(unet.walk(Controlnet))) == 0
|
assert len(list(unet.walk(Controlnet))) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_two_controlnets_same_name(unet: SD1UNet) -> None:
|
||||||
|
SD1ControlnetAdapter(unet, name="cnx").inject()
|
||||||
|
cn2 = SD1ControlnetAdapter(unet, name="cnx")
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError) as exc:
|
||||||
|
cn2.inject()
|
||||||
|
assert "Controlnet named cnx is already injected" in str(exc.value)
|
||||||
|
|
Loading…
Reference in a new issue