Mirror of https://github.com/finegrain-ai/refiners.git
make set_scale for T2I-Adapter really dynamic
Before this change, `set_scale` only affected the condition encoder, so calling `set_scale` after `set_condition_features` had no effect at runtime.
parent 694661ee10, commit 9fbe86fbc9
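
To see the bug in miniature: the old code applied the scale inside the condition encoder, so the scale was baked into the features that `set_condition_features` cached. A minimal plain-Python sketch of that failure mode (all names here are illustrative, not refiners API):

class OldStyleEncoder:
    """Toy stand-in for the old ConditionEncoder: scales at encode time."""

    def __init__(self) -> None:
        self.scale = 1.0

    def encode(self, x: float) -> float:
        return self.scale * x  # scale is baked into the output here


encoder = OldStyleEncoder()
cached_features = encoder.encode(10.0)  # analogous to set_condition_features()
encoder.scale = 0.5                     # analogous to the old set_scale()
print(cached_features)                  # still 10.0: the cache never sees the new scale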

SD1 adapter (`SD1T2IAdapter`):

@@ -1,5 +1,3 @@
-from typing import cast, Iterable
-
 from torch import Tensor
 
 from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder
@@ -13,9 +11,11 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
         target: SD1UNet,
         name: str,
         condition_encoder: ConditionEncoder | None = None,
+        scale: float = 1.0,
         weights: dict[str, Tensor] | None = None,
     ) -> None:
         self.residual_indices = (2, 5, 8, 11)
+        self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
         super().__init__(
             target=target,
             name=name,
@@ -24,23 +24,14 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
         )
 
     def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
-        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
-            if n not in self.residual_indices:
-                continue
+        for n, feat in zip(self.residual_indices, self._features, strict=True):
+            block = self.target.DownBlocks[n]
             for t2i_layer in block.layers(layer_type=T2IFeatures):
                 assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
-            block.insert_before_type(
-                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
-            )
+            block.insert_before_type(ResidualAccumulator, feat)
         return super().inject(parent)
 
     def eject(self: "SD1T2IAdapter") -> None:
-        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
-            if n not in self.residual_indices:
-                continue
-            t2i_layers = [
-                t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
-            ]
-            assert len(t2i_layers) == 1
-            block.remove(t2i_layers.pop())
+        for n, feat in zip(self.residual_indices, self._features, strict=True):
+            self.target.DownBlocks[n].remove(feat)
         super().eject()
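
A side effect of storing the injected layers in `self._features`: `eject` can now remove each layer directly, instead of the old approach of scanning every block for layers matching the adapter's name. A toy illustration of that identity-based pattern (plain Python, illustrative names):

class Block:
    def __init__(self) -> None:
        self.children: list[object] = []

    def remove(self, child: object) -> None:
        self.children.remove(child)  # matches by identity for plain objects


features = [object() for _ in range(4)]  # stand-ins for the stored T2IFeatures
blocks = [Block() for _ in range(4)]

for block, feat in zip(blocks, features, strict=True):  # "inject"
    block.children.append(feat)

for block, feat in zip(blocks, features, strict=True):  # "eject"
    block.remove(feat)

assert all(not block.children for block in blocks)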

SDXL adapter (`SDXLT2IAdapter`):

@@ -1,5 +1,3 @@
-from typing import cast, Iterable
-
 from torch import Tensor
 
 from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL
@@ -14,9 +12,11 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
         target: SDXLUNet,
         name: str,
         condition_encoder: ConditionEncoderXL | None = None,
+        scale: float = 1.0,
         weights: dict[str, Tensor] | None = None,
     ) -> None:
         self.residual_indices = (3, 5, 8)  # the UNet's middle block is handled separately (see `inject` and `eject`)
+        self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
         super().__init__(
             target=target,
             name=name,
@@ -29,29 +29,20 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
             for t2i_layer in block.layers(layer_type=T2IFeatures):
                 assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
 
-        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
-            if n not in self.residual_indices:
-                continue
+        # Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
+        for n, feat in zip(self.residual_indices, self._features, strict=False):
+            block = self.target.DownBlocks[n]
             sanity_check_t2i(block)
-            block.insert_before_type(
-                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
-            )
+            block.insert_before_type(ResidualAccumulator, feat)
 
-        sanity_check_t2i(self.target.MiddleBlock)
         # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
-        self.target.MiddleBlock.append(T2IFeatures(name=self.name, index=-1))
+        sanity_check_t2i(self.target.MiddleBlock)
+        self.target.MiddleBlock.append(self._features[-1])
         return super().inject(parent)
 
     def eject(self: "SDXLT2IAdapter") -> None:
-        def eject_t2i(block: fl.Module) -> None:
-            t2i_layers = [
-                t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
-            ]
-            assert len(t2i_layers) == 1
-            block.remove(t2i_layers.pop())
-
-        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
-            if n not in self.residual_indices:
-                continue
-            eject_t2i(block)
-        eject_t2i(self.target.MiddleBlock)
+        # See `inject` re: `strict=False`
+        for n, feat in zip(self.residual_indices, self._features, strict=False):
+            self.target.DownBlocks[n].remove(feat)
+        self.target.MiddleBlock.remove(self._features[-1])
         super().eject()
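
One subtlety in the SDXL version above: SD1 zips with `strict=True` while SDXL uses `strict=False`, because SDXL has only three down-block residual indices and routes the fourth feature through the MiddleBlock. A quick standalone illustration of the difference (Python 3.10+):

# SD1: four residual indices for four features, lengths must match.
assert list(zip((2, 5, 8, 11), range(4), strict=True)) == [(2, 0), (5, 1), (8, 2), (11, 3)]

# SDXL: three indices for four features; strict=False stops at the shortest
# input, leaving the last feature to be appended to the MiddleBlock.
assert list(zip((3, 5, 8), range(4), strict=False)) == [(3, 0), (5, 1), (8, 2)]

# With strict=True, this length mismatch would raise ValueError instead.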

Core module (`refiners.foundationals.latent_diffusion.t2i_adapter`):

@@ -117,13 +117,9 @@ class ConditionEncoder(fl.Chain):
                 )
                 for i in range(1, len(channels))
             ),
-            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
+            fl.UseContext(context="t2iadapter", key="features"),
         )
 
-    def scale_outputs(self, features: list[Tensor]) -> tuple[Tensor, ...]:
-        assert len(features) == 4
-        return tuple([x * self.scale for x in features])
-
     def init_context(self) -> Contexts:
         return {"t2iadapter": {"features": []}}
 
@@ -157,23 +153,25 @@ class ConditionEncoderXL(ConditionEncoder, fl.Chain):
                 channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
             ),
             StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
-            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
+            fl.UseContext(context="t2iadapter", key="features"),
         )
 
 
 class T2IFeatures(fl.Residual):
-    def __init__(self, name: str, index: int) -> None:
+    def __init__(self, name: str, index: int, scale: float = 1.0) -> None:
         self.name = name
         self.index = index
+        self.scale = scale
         super().__init__(
             fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
-                func=lambda features: features[self.index]
+                func=lambda features: self.scale * features[self.index]
             )
         )
 
 
 class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
     _condition_encoder: list[ConditionEncoder]  # prevent PyTorch module registration
+    _features: list[T2IFeatures] = []
 
     def __init__(
         self,
@@ -207,7 +205,8 @@ class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
         self.set_context("t2iadapter", {f"condition_features_{self.name}": features})
 
     def set_scale(self, scale: float) -> None:
-        self.condition_encoder.scale = scale
+        for f in self._features:
+            f.scale = scale
 
     def init_context(self) -> Contexts:
         return {"t2iadapter": {f"condition_features_{self.name}": None}}
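
And the new behavior in the same toy terms as the sketch at the top: the scale now lives on the injected feature layers and is applied at forward time, so the cached features stay unscaled and later `set_scale` calls take effect. (Again a plain-Python sketch with illustrative names, not refiners API.)

class NewStyleFeature:
    """Toy stand-in for T2IFeatures: scale applied when the layer runs."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def __call__(self, cached_feature: float) -> float:
        return self.scale * cached_feature  # scale read at forward time


feature = NewStyleFeature()     # pretend this sits inside a UNet block
cached_feature = 10.0           # cached unscaled by set_condition_features()
feature.scale = 0.5             # new set_scale(): mutates the injected layer
print(feature(cached_feature))  # 5.0: the updated scale is honored

In refiners terms, this is what makes a per-step schedule such as calling `adapter.set_scale(0.0)` halfway through sampling actually change the UNet's output.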