mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
expose lookup_top_adapter
This commit is contained in:
parent
f4e9707297
commit
3c056e2231
|
@ -36,27 +36,6 @@ class Adapter(Generic[T]):
|
|||
yield
|
||||
target._can_refresh_parent = _old_can_refresh_parent
|
||||
|
||||
def lookup_actual_target(self) -> fl.Module:
|
||||
# In general, the "actual target" is the target.
|
||||
# This method deals with the edge case where the target
|
||||
# is part of the replacement block and has been adapted by
|
||||
# another adapter after this one. For instance, this is the
|
||||
# case when stacking Controlnets.
|
||||
assert isinstance(self, fl.Chain)
|
||||
|
||||
target_parent = self.find_parent(self.target)
|
||||
if (target_parent is None) or (target_parent == self):
|
||||
return self.target
|
||||
|
||||
# Lookup and return last adapter in parents tree (or target if none).
|
||||
r, p = self.target, target_parent
|
||||
while p != self:
|
||||
if isinstance(p, Adapter):
|
||||
r = p
|
||||
assert p.parent, f"parent tree of {self} is broken"
|
||||
p = p.parent
|
||||
return r
|
||||
|
||||
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
|
||||
assert isinstance(self, fl.Chain)
|
||||
|
||||
|
@ -87,7 +66,13 @@ class Adapter(Generic[T]):
|
|||
|
||||
def eject(self) -> None:
|
||||
assert isinstance(self, fl.Chain)
|
||||
actual_target = self.lookup_actual_target()
|
||||
|
||||
# In general, the "actual target" is the target.
|
||||
# Here we deal with the edge case where the target
|
||||
# is part of the replacement block and has been adapted by
|
||||
# another adapter after this one. For instance, this is the
|
||||
# case when stacking Controlnets.
|
||||
actual_target = lookup_top_adapter(self, self.target)
|
||||
|
||||
if (parent := self.parent) is None:
|
||||
if isinstance(actual_target, fl.ContextModule):
|
||||
|
@ -101,3 +86,19 @@ class Adapter(Generic[T]):
|
|||
|
||||
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
|
||||
self._target = [source.target]
|
||||
|
||||
|
||||
def lookup_top_adapter(top: fl.Chain, target: fl.Module) -> fl.Module:
|
||||
"""Lookup and return last adapter in parents tree (or target if none)."""
|
||||
|
||||
target_parent = top.find_parent(target)
|
||||
if (target_parent is None) or (target_parent == top):
|
||||
return target
|
||||
|
||||
r, p = target, target_parent
|
||||
while p != top:
|
||||
if isinstance(p, Adapter):
|
||||
r = p
|
||||
assert p.parent, f"parent tree of {top} is broken"
|
||||
p = p.parent
|
||||
return r
|
||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
|||
import pytest
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import lookup_top_adapter
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
||||
|
||||
|
@ -56,7 +57,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
|
|||
assert cn1.parent == original_parent
|
||||
assert len(list(unet.walk(Controlnet))) == 2
|
||||
assert cn1.target == unet
|
||||
assert cn1.lookup_actual_target() == cn2
|
||||
assert lookup_top_adapter(cn1, cn1.target) == cn2
|
||||
|
||||
cn2.eject()
|
||||
assert unet.parent == cn1
|
||||
|
|
Loading…
Reference in a new issue