expose lookup_top_adapter

This commit is contained in:
Pierre Chapuis 2023-09-08 18:25:30 +02:00
parent f4e9707297
commit 3c056e2231
2 changed files with 25 additions and 23 deletions

View file

@ -36,27 +36,6 @@ class Adapter(Generic[T]):
yield
target._can_refresh_parent = _old_can_refresh_parent
def lookup_actual_target(self) -> fl.Module:
# In general, the "actual target" is the target.
# This method deals with the edge case where the target
# is part of the replacement block and has been adapted by
# another adapter after this one. For instance, this is the
# case when stacking Controlnets.
assert isinstance(self, fl.Chain)
target_parent = self.find_parent(self.target)
if (target_parent is None) or (target_parent == self):
return self.target
# Lookup and return last adapter in parents tree (or target if none).
r, p = self.target, target_parent
while p != self:
if isinstance(p, Adapter):
r = p
assert p.parent, f"parent tree of {self} is broken"
p = p.parent
return r
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
assert isinstance(self, fl.Chain)
@ -87,7 +66,13 @@ class Adapter(Generic[T]):
def eject(self) -> None:
assert isinstance(self, fl.Chain)
actual_target = self.lookup_actual_target()
# In general, the "actual target" is the target.
# Here we deal with the edge case where the target
# is part of the replacement block and has been adapted by
# another adapter after this one. For instance, this is the
# case when stacking Controlnets.
actual_target = lookup_top_adapter(self, self.target)
if (parent := self.parent) is None:
if isinstance(actual_target, fl.ContextModule):
@ -101,3 +86,19 @@ class Adapter(Generic[T]):
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
self._target = [source.target]
def lookup_top_adapter(top: fl.Chain, target: fl.Module) -> fl.Module:
"""Lookup and return last adapter in parents tree (or target if none)."""
target_parent = top.find_parent(target)
if (target_parent is None) or (target_parent == top):
return target
r, p = target, target_parent
while p != top:
if isinstance(p, Adapter):
r = p
assert p.parent, f"parent tree of {top} is broken"
p = p.parent
return r

View file

@ -4,6 +4,7 @@ import torch
import pytest
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import lookup_top_adapter
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
@ -56,7 +57,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
assert cn1.parent == original_parent
assert len(list(unet.walk(Controlnet))) == 2
assert cn1.target == unet
assert cn1.lookup_actual_target() == cn2
assert lookup_top_adapter(cn1, cn1.target) == cn2
cn2.eject()
assert unet.parent == cn1