diff --git a/src/refiners/fluxion/adapters/adapter.py b/src/refiners/fluxion/adapters/adapter.py index ea838f9..03a5146 100644 --- a/src/refiners/fluxion/adapters/adapter.py +++ b/src/refiners/fluxion/adapters/adapter.py @@ -36,27 +36,6 @@ class Adapter(Generic[T]): yield target._can_refresh_parent = _old_can_refresh_parent - def lookup_actual_target(self) -> fl.Module: - # In general, the "actual target" is the target. - # This method deals with the edge case where the target - # is part of the replacement block and has been adapted by - # another adapter after this one. For instance, this is the - # case when stacking Controlnets. - assert isinstance(self, fl.Chain) - - target_parent = self.find_parent(self.target) - if (target_parent is None) or (target_parent == self): - return self.target - - # Lookup and return last adapter in parents tree (or target if none). - r, p = self.target, target_parent - while p != self: - if isinstance(p, Adapter): - r = p - assert p.parent, f"parent tree of {self} is broken" - p = p.parent - return r - def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter: assert isinstance(self, fl.Chain) @@ -87,7 +66,13 @@ class Adapter(Generic[T]): def eject(self) -> None: assert isinstance(self, fl.Chain) - actual_target = self.lookup_actual_target() + + # In general, the "actual target" is the target. + # Here we deal with the edge case where the target + # is part of the replacement block and has been adapted by + # another adapter after this one. For instance, this is the + # case when stacking Controlnets. + actual_target = lookup_top_adapter(self, self.target) if (parent := self.parent) is None: if isinstance(actual_target, fl.ContextModule): @@ -101,3 +86,19 @@ class Adapter(Generic[T]): def _post_structural_copy(self: TAdapter, source: TAdapter) -> None: self._target = [source.target] + + +def lookup_top_adapter(top: fl.Chain, target: fl.Module) -> fl.Module: + """Lookup and return last adapter in parents tree (or target if none).""" + + target_parent = top.find_parent(target) + if (target_parent is None) or (target_parent == top): + return target + + r, p = target, target_parent + while p != top: + if isinstance(p, Adapter): + r = p + assert p.parent, f"parent tree of {top} is broken" + p = p.parent + return r diff --git a/tests/foundationals/latent_diffusion/test_controlnet.py b/tests/foundationals/latent_diffusion/test_controlnet.py index b89b384..92f7662 100644 --- a/tests/foundationals/latent_diffusion/test_controlnet.py +++ b/tests/foundationals/latent_diffusion/test_controlnet.py @@ -4,6 +4,7 @@ import torch import pytest import refiners.fluxion.layers as fl +from refiners.fluxion.adapters.adapter import lookup_top_adapter from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet @@ -56,7 +57,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None: assert cn1.parent == original_parent assert len(list(unet.walk(Controlnet))) == 2 assert cn1.target == unet - assert cn1.lookup_actual_target() == cn2 + assert lookup_top_adapter(cn1, cn1.target) == cn2 cn2.eject() assert unet.parent == cn1