mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
make set_scale
for T2I-Adapter really dynamic
Before this change, `set_scale` had only an impact on the condition encoder. So calling `set_scale` after `set_condition_features` had no effect at runtime.
This commit is contained in:
parent
694661ee10
commit
9fbe86fbc9
|
@ -1,5 +1,3 @@
|
|||
from typing import cast, Iterable
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder
|
||||
|
@ -13,9 +11,11 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
|
|||
target: SD1UNet,
|
||||
name: str,
|
||||
condition_encoder: ConditionEncoder | None = None,
|
||||
scale: float = 1.0,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
self.residual_indices = (2, 5, 8, 11)
|
||||
self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
|
||||
super().__init__(
|
||||
target=target,
|
||||
name=name,
|
||||
|
@ -24,23 +24,14 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
|
|||
)
|
||||
|
||||
def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
|
||||
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
|
||||
if n not in self.residual_indices:
|
||||
continue
|
||||
for n, feat in zip(self.residual_indices, self._features, strict=True):
|
||||
block = self.target.DownBlocks[n]
|
||||
for t2i_layer in block.layers(layer_type=T2IFeatures):
|
||||
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
|
||||
block.insert_before_type(
|
||||
ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
|
||||
)
|
||||
block.insert_before_type(ResidualAccumulator, feat)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self: "SD1T2IAdapter") -> None:
|
||||
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
|
||||
if n not in self.residual_indices:
|
||||
continue
|
||||
t2i_layers = [
|
||||
t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
|
||||
]
|
||||
assert len(t2i_layers) == 1
|
||||
block.remove(t2i_layers.pop())
|
||||
for n, feat in zip(self.residual_indices, self._features, strict=True):
|
||||
self.target.DownBlocks[n].remove(feat)
|
||||
super().eject()
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
from typing import cast, Iterable
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL
|
||||
|
@ -14,9 +12,11 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
|
|||
target: SDXLUNet,
|
||||
name: str,
|
||||
condition_encoder: ConditionEncoderXL | None = None,
|
||||
scale: float = 1.0,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
self.residual_indices = (3, 5, 8) # the UNet's middle block is handled separately (see `inject` and `eject`)
|
||||
self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
|
||||
super().__init__(
|
||||
target=target,
|
||||
name=name,
|
||||
|
@ -29,29 +29,20 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
|
|||
for t2i_layer in block.layers(layer_type=T2IFeatures):
|
||||
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
|
||||
|
||||
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
|
||||
if n not in self.residual_indices:
|
||||
continue
|
||||
# Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
|
||||
for n, feat in zip(self.residual_indices, self._features, strict=False):
|
||||
block = self.target.DownBlocks[n]
|
||||
sanity_check_t2i(block)
|
||||
block.insert_before_type(
|
||||
ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
|
||||
)
|
||||
sanity_check_t2i(self.target.MiddleBlock)
|
||||
block.insert_before_type(ResidualAccumulator, feat)
|
||||
|
||||
# Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
|
||||
self.target.MiddleBlock.append(T2IFeatures(name=self.name, index=-1))
|
||||
sanity_check_t2i(self.target.MiddleBlock)
|
||||
self.target.MiddleBlock.append(self._features[-1])
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self: "SDXLT2IAdapter") -> None:
|
||||
def eject_t2i(block: fl.Module) -> None:
|
||||
t2i_layers = [
|
||||
t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
|
||||
]
|
||||
assert len(t2i_layers) == 1
|
||||
block.remove(t2i_layers.pop())
|
||||
|
||||
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
|
||||
if n not in self.residual_indices:
|
||||
continue
|
||||
eject_t2i(block)
|
||||
eject_t2i(self.target.MiddleBlock)
|
||||
# See `inject` re: `strict=False`
|
||||
for n, feat in zip(self.residual_indices, self._features, strict=False):
|
||||
self.target.DownBlocks[n].remove(feat)
|
||||
self.target.MiddleBlock.remove(self._features[-1])
|
||||
super().eject()
|
||||
|
|
|
@ -117,13 +117,9 @@ class ConditionEncoder(fl.Chain):
|
|||
)
|
||||
for i in range(1, len(channels))
|
||||
),
|
||||
fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
|
||||
fl.UseContext(context="t2iadapter", key="features"),
|
||||
)
|
||||
|
||||
def scale_outputs(self, features: list[Tensor]) -> tuple[Tensor, ...]:
|
||||
assert len(features) == 4
|
||||
return tuple([x * self.scale for x in features])
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"t2iadapter": {"features": []}}
|
||||
|
||||
|
@ -157,23 +153,25 @@ class ConditionEncoderXL(ConditionEncoder, fl.Chain):
|
|||
channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
|
||||
),
|
||||
StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
|
||||
fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
|
||||
fl.UseContext(context="t2iadapter", key="features"),
|
||||
)
|
||||
|
||||
|
||||
class T2IFeatures(fl.Residual):
|
||||
def __init__(self, name: str, index: int) -> None:
|
||||
def __init__(self, name: str, index: int, scale: float = 1.0) -> None:
|
||||
self.name = name
|
||||
self.index = index
|
||||
self.scale = scale
|
||||
super().__init__(
|
||||
fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
|
||||
func=lambda features: features[self.index]
|
||||
func=lambda features: self.scale * features[self.index]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
_condition_encoder: list[ConditionEncoder] # prevent PyTorch module registration
|
||||
_features: list[T2IFeatures] = []
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -207,7 +205,8 @@ class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
self.set_context("t2iadapter", {f"condition_features_{self.name}": features})
|
||||
|
||||
def set_scale(self, scale: float) -> None:
|
||||
self.condition_encoder.scale = scale
|
||||
for f in self._features:
|
||||
f.scale = scale
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"t2iadapter": {f"condition_features_{self.name}": None}}
|
||||
|
|
Loading…
Reference in a new issue