diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py
index d263afc..7cbb49c 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py
@@ -1,5 +1,3 @@
-from typing import cast, Iterable
-
 from torch import Tensor
 
 from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder
@@ -13,9 +11,11 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
         target: SD1UNet,
         name: str,
         condition_encoder: ConditionEncoder | None = None,
+        scale: float = 1.0,
         weights: dict[str, Tensor] | None = None,
     ) -> None:
         self.residual_indices = (2, 5, 8, 11)
+        self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
         super().__init__(
             target=target,
             name=name,
@@ -24,23 +24,14 @@ class SD1T2IAdapter(T2IAdapter[SD1UNet]):
         )
 
     def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
-        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
-            if n not in self.residual_indices:
-                continue
+        for n, feat in zip(self.residual_indices, self._features, strict=True):
+            block = self.target.DownBlocks[n]
             for t2i_layer in block.layers(layer_type=T2IFeatures):
                 assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
-            block.insert_before_type(
-                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
-            )
+            block.insert_before_type(ResidualAccumulator, feat)
         return super().inject(parent)
 
     def eject(self: "SD1T2IAdapter") -> None:
-        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
-            if n not in self.residual_indices:
-                continue
-            t2i_layers = [
-                t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
-            ]
-            assert len(t2i_layers) == 1
-            block.remove(t2i_layers.pop())
+        for n, feat in zip(self.residual_indices, self._features, strict=True):
+            self.target.DownBlocks[n].remove(feat)
         super().eject()
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py
index 65be7e4..af3830d 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py
@@ -1,5 +1,3 @@
-from typing import cast, Iterable
-
 from torch import Tensor
 
 from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL
@@ -14,9 +12,11 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
         target: SDXLUNet,
         name: str,
         condition_encoder: ConditionEncoderXL | None = None,
+        scale: float = 1.0,
         weights: dict[str, Tensor] | None = None,
     ) -> None:
         self.residual_indices = (3, 5, 8)  # the UNet's middle block is handled separately (see `inject` and `eject`)
+        self._features = [T2IFeatures(name=name, index=i, scale=scale) for i in range(4)]
         super().__init__(
             target=target,
             name=name,
@@ -29,29 +29,20 @@ class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
             for t2i_layer in block.layers(layer_type=T2IFeatures):
                 assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
 
-        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
-            if n not in self.residual_indices:
-                continue
+        # Note: `strict=False` because `residual_indices` is shorter than `_features` due to MiddleBlock (see below)
+        for n, feat in zip(self.residual_indices, self._features, strict=False):
+            block = self.target.DownBlocks[n]
             sanity_check_t2i(block)
-            block.insert_before_type(
-                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
-            )
-        sanity_check_t2i(self.target.MiddleBlock)
+            block.insert_before_type(ResidualAccumulator, feat)
+
         # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
-        self.target.MiddleBlock.append(T2IFeatures(name=self.name, index=-1))
+        sanity_check_t2i(self.target.MiddleBlock)
+        self.target.MiddleBlock.append(self._features[-1])
         return super().inject(parent)
 
     def eject(self: "SDXLT2IAdapter") -> None:
-        def eject_t2i(block: fl.Module) -> None:
-            t2i_layers = [
-                t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
-            ]
-            assert len(t2i_layers) == 1
-            block.remove(t2i_layers.pop())
-
-        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
-            if n not in self.residual_indices:
-                continue
-            eject_t2i(block)
-        eject_t2i(self.target.MiddleBlock)
+        # See `inject` re: `strict=False`
+        for n, feat in zip(self.residual_indices, self._features, strict=False):
+            self.target.DownBlocks[n].remove(feat)
+        self.target.MiddleBlock.remove(self._features[-1])
         super().eject()
diff --git a/src/refiners/foundationals/latent_diffusion/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/t2i_adapter.py
index ab2ef8b..1efdb52 100644
--- a/src/refiners/foundationals/latent_diffusion/t2i_adapter.py
+++ b/src/refiners/foundationals/latent_diffusion/t2i_adapter.py
@@ -117,13 +117,9 @@ class ConditionEncoder(fl.Chain):
                 )
                 for i in range(1, len(channels))
             ),
-            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
+            fl.UseContext(context="t2iadapter", key="features"),
         )
 
-    def scale_outputs(self, features: list[Tensor]) -> tuple[Tensor, ...]:
-        assert len(features) == 4
-        return tuple([x * self.scale for x in features])
-
     def init_context(self) -> Contexts:
         return {"t2iadapter": {"features": []}}
@@ -157,23 +153,25 @@ class ConditionEncoderXL(ConditionEncoder, fl.Chain):
                 channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
             ),
             StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
-            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
+            fl.UseContext(context="t2iadapter", key="features"),
         )
 
 
 class T2IFeatures(fl.Residual):
-    def __init__(self, name: str, index: int) -> None:
+    def __init__(self, name: str, index: int, scale: float = 1.0) -> None:
         self.name = name
         self.index = index
+        self.scale = scale
         super().__init__(
             fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
-                func=lambda features: features[self.index]
+                func=lambda features: self.scale * features[self.index]
             )
         )
 
 
 class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
     _condition_encoder: list[ConditionEncoder]  # prevent PyTorch module registration
+    _features: list[T2IFeatures] = []
 
     def __init__(
         self,
@@ -207,7 +205,8 @@ class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
         self.set_context("t2iadapter", {f"condition_features_{self.name}": features})
 
     def set_scale(self, scale: float) -> None:
-        self.condition_encoder.scale = scale
+        for f in self._features:
+            f.scale = scale
 
     def init_context(self) -> Contexts:
         return {"t2iadapter": {f"condition_features_{self.name}": None}}