mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
expose lookup_top_adapter
This commit is contained in:
parent
f4e9707297
commit
3c056e2231
|
@ -36,27 +36,6 @@ class Adapter(Generic[T]):
|
||||||
yield
|
yield
|
||||||
target._can_refresh_parent = _old_can_refresh_parent
|
target._can_refresh_parent = _old_can_refresh_parent
|
||||||
|
|
||||||
def lookup_actual_target(self) -> fl.Module:
|
|
||||||
# In general, the "actual target" is the target.
|
|
||||||
# This method deals with the edge case where the target
|
|
||||||
# is part of the replacement block and has been adapted by
|
|
||||||
# another adapter after this one. For instance, this is the
|
|
||||||
# case when stacking Controlnets.
|
|
||||||
assert isinstance(self, fl.Chain)
|
|
||||||
|
|
||||||
target_parent = self.find_parent(self.target)
|
|
||||||
if (target_parent is None) or (target_parent == self):
|
|
||||||
return self.target
|
|
||||||
|
|
||||||
# Lookup and return last adapter in parents tree (or target if none).
|
|
||||||
r, p = self.target, target_parent
|
|
||||||
while p != self:
|
|
||||||
if isinstance(p, Adapter):
|
|
||||||
r = p
|
|
||||||
assert p.parent, f"parent tree of {self} is broken"
|
|
||||||
p = p.parent
|
|
||||||
return r
|
|
||||||
|
|
||||||
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
|
def inject(self: TAdapter, parent: fl.Chain | None = None) -> TAdapter:
|
||||||
assert isinstance(self, fl.Chain)
|
assert isinstance(self, fl.Chain)
|
||||||
|
|
||||||
|
@ -87,7 +66,13 @@ class Adapter(Generic[T]):
|
||||||
|
|
||||||
def eject(self) -> None:
|
def eject(self) -> None:
|
||||||
assert isinstance(self, fl.Chain)
|
assert isinstance(self, fl.Chain)
|
||||||
actual_target = self.lookup_actual_target()
|
|
||||||
|
# In general, the "actual target" is the target.
|
||||||
|
# Here we deal with the edge case where the target
|
||||||
|
# is part of the replacement block and has been adapted by
|
||||||
|
# another adapter after this one. For instance, this is the
|
||||||
|
# case when stacking Controlnets.
|
||||||
|
actual_target = lookup_top_adapter(self, self.target)
|
||||||
|
|
||||||
if (parent := self.parent) is None:
|
if (parent := self.parent) is None:
|
||||||
if isinstance(actual_target, fl.ContextModule):
|
if isinstance(actual_target, fl.ContextModule):
|
||||||
|
@ -101,3 +86,19 @@ class Adapter(Generic[T]):
|
||||||
|
|
||||||
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
|
def _post_structural_copy(self: TAdapter, source: TAdapter) -> None:
|
||||||
self._target = [source.target]
|
self._target = [source.target]
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_top_adapter(top: fl.Chain, target: fl.Module) -> fl.Module:
|
||||||
|
"""Lookup and return last adapter in parents tree (or target if none)."""
|
||||||
|
|
||||||
|
target_parent = top.find_parent(target)
|
||||||
|
if (target_parent is None) or (target_parent == top):
|
||||||
|
return target
|
||||||
|
|
||||||
|
r, p = target, target_parent
|
||||||
|
while p != top:
|
||||||
|
if isinstance(p, Adapter):
|
||||||
|
r = p
|
||||||
|
assert p.parent, f"parent tree of {top} is broken"
|
||||||
|
p = p.parent
|
||||||
|
return r
|
||||||
|
|
|
@ -4,6 +4,7 @@ import torch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.fluxion.adapters.adapter import lookup_top_adapter
|
||||||
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
|
from refiners.foundationals.latent_diffusion import SD1UNet, SD1ControlnetAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
||||||
|
|
||||||
|
@ -56,7 +57,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
|
||||||
assert cn1.parent == original_parent
|
assert cn1.parent == original_parent
|
||||||
assert len(list(unet.walk(Controlnet))) == 2
|
assert len(list(unet.walk(Controlnet))) == 2
|
||||||
assert cn1.target == unet
|
assert cn1.target == unet
|
||||||
assert cn1.lookup_actual_target() == cn2
|
assert lookup_top_adapter(cn1, cn1.target) == cn2
|
||||||
|
|
||||||
cn2.eject()
|
cn2.eject()
|
||||||
assert unet.parent == cn1
|
assert unet.parent == cn1
|
||||||
|
|
Loading…
Reference in a new issue