make set_scale for T2I-Adapter really dynamic

Before this change, `set_scale` only affected the condition
encoder, so calling `set_scale` after `set_condition_features` had no
effect at runtime.
This commit is contained in:
Cédric Deltheil 2023-10-04 11:08:30 +02:00 committed by Cédric Deltheil
parent 694661ee10
commit 9fbe86fbc9
3 changed files with 28 additions and 47 deletions

View file

@@ -1,5 +1,3 @@
from typing import cast, Iterable
from torch import Tensor
from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder
@@ -13,9 +11,11 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
target: SD1UNet,
name: str,
condition_encoder: ConditionEncoder | None = None,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
self.residual_indices = (2, 5, 8, 11)
self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
super().__init__(
target=target,
name=name,
@@ -24,23 +24,14 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
)
def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
if n not in self.residual_indices:
continue
for n, feat in zip(self.residual_indices, self._features, strict=True):
block = self.target.DownBlocks[n]
for t2i_layer in block.layers(layer_type=T2IFeatures):
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
block.insert_before_type(
ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
)
block.insert_before_type(ResidualAccumulator, feat)
return super().inject(parent)
def eject(self: "SD1T2IAdapter") -> None:
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
if n not in self.residual_indices:
continue
t2i_layers = [
t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
]
assert len(t2i_layers) == 1
block.remove(t2i_layers.pop())
for n, feat in zip(self.residual_indices, self._features, strict=True):
self.target.DownBlocks[n].remove(feat)
super().eject()

View file

@@ -1,5 +1,3 @@
from typing import cast, Iterable
from torch import Tensor
from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL
@@ -14,9 +12,11 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
target: SDXLUNet,
name: str,
condition_encoder: ConditionEncoderXL | None = None,
scale: float = 1.0,
weights: dict[str, Tensor] | None = None,
) -> None:
self.residual_indices = (3, 5, 8) # the UNet's middle block is handled separately (see `inject` and `eject`)
self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
super().__init__(
target=target,
name=name,
@@ -29,29 +29,20 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
for t2i_layer in block.layers(layer_type=T2IFeatures):
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
if n not in self.residual_indices:
continue
# Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
for n, feat in zip(self.residual_indices, self._features, strict=False):
block = self.target.DownBlocks[n]
sanity_check_t2i(block)
block.insert_before_type(
ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
)
sanity_check_t2i(self.target.MiddleBlock)
block.insert_before_type(ResidualAccumulator, feat)
# Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
self.target.MiddleBlock.append(T2IFeatures(name=self.name, index=-1))
sanity_check_t2i(self.target.MiddleBlock)
self.target.MiddleBlock.append(self._features[-1])
return super().inject(parent)
def eject(self: "SDXLT2IAdapter") -> None:
def eject_t2i(block: fl.Module) -> None:
t2i_layers = [
t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
]
assert len(t2i_layers) == 1
block.remove(t2i_layers.pop())
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
if n not in self.residual_indices:
continue
eject_t2i(block)
eject_t2i(self.target.MiddleBlock)
# See `inject` re: `strict=False`
for n, feat in zip(self.residual_indices, self._features, strict=False):
self.target.DownBlocks[n].remove(feat)
self.target.MiddleBlock.remove(self._features[-1])
super().eject()

View file

@@ -117,13 +117,9 @@ class ConditionEncoder(fl.Chain):
)
for i in range(1, len(channels))
),
fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
fl.UseContext(context="t2iadapter", key="features"),
)
def scale_outputs(self, features: list[Tensor]) -> tuple[Tensor, ...]:
assert len(features) == 4
return tuple([x * self.scale for x in features])
def init_context(self) -> Contexts:
return {"t2iadapter": {"features": []}}
@@ -157,23 +153,25 @@ class ConditionEncoderXL(ConditionEncoder, fl.Chain):
channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
),
StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
fl.UseContext(context="t2iadapter", key="features"),
)
class T2IFeatures(fl.Residual):
def __init__(self, name: str, index: int) -> None:
def __init__(self, name: str, index: int, scale: float = 1.0) -> None:
self.name = name
self.index = index
self.scale = scale
super().__init__(
fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
func=lambda features: features[self.index]
func=lambda features: self.scale * features[self.index]
)
)
class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
_condition_encoder: list[ConditionEncoder] # prevent PyTorch module registration
_features: list[T2IFeatures] = []
def __init__(
self,
@@ -207,7 +205,8 @@ class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
self.set_context("t2iadapter", {f"condition_features_{self.name}": features})
def set_scale(self, scale: float) -> None:
self.condition_encoder.scale = scale
for f in self._features:
f.scale = scale
def init_context(self) -> Contexts:
return {"t2iadapter": {f"condition_features_{self.name}": None}}