mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
check no two controlnets have the same name
This commit is contained in:
parent
bd59790e08
commit
e91e31ebd2
|
@ -72,7 +72,7 @@ class ConditionEncoder(Chain):
|
|||
|
||||
|
||||
class Controlnet(Passthrough):
|
||||
structural_attrs = ["scale"]
|
||||
structural_attrs = ["scale", "name"]
|
||||
|
||||
def __init__(
|
||||
self, name: str, scale: float = 1.0, device: Device | str | None = None, dtype: DType | None = None
|
||||
|
@ -84,6 +84,7 @@ class Controlnet(Passthrough):
|
|||
|
||||
It has to use the same context as the UNet: `unet` and `sampling`.
|
||||
"""
|
||||
self.name = name
|
||||
self.scale = scale
|
||||
super().__init__(
|
||||
TimestepEncoder(context_key=f"timestep_embedding_{name}", device=device, dtype=dtype),
|
||||
|
@ -159,7 +160,10 @@ class SD1ControlnetAdapter(Chain, Adapter[SD1UNet]):
|
|||
|
||||
def inject(self: "SD1ControlnetAdapter", parent: Chain | None = None) -> "SD1ControlnetAdapter":
|
||||
controlnet = self._controlnet[0]
|
||||
assert controlnet not in self.target, f"{controlnet} is already injected"
|
||||
target_controlnets = [x for x in self.target if isinstance(x, Controlnet)]
|
||||
assert controlnet not in target_controlnets, f"{controlnet} is already injected"
|
||||
for cn in target_controlnets:
|
||||
assert cn.name != self.name, f"Controlnet named {self.name} is already injected"
|
||||
self.target.insert(0, controlnet)
|
||||
return super().inject(parent)
|
||||
|
||||
|
|
|
@ -83,3 +83,13 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
|
|||
cn2.eject()
|
||||
assert unet.parent == original_parent
|
||||
assert len(list(unet.walk(Controlnet))) == 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_two_controlnets_same_name(unet: SD1UNet) -> None:
|
||||
SD1ControlnetAdapter(unet, name="cnx").inject()
|
||||
cn2 = SD1ControlnetAdapter(unet, name="cnx")
|
||||
|
||||
with pytest.raises(AssertionError) as exc:
|
||||
cn2.inject()
|
||||
assert "Controlnet named cnx is already injected" in str(exc.value)
|
||||
|
|
Loading…
Reference in a new issue