diff --git a/src/refiners/foundationals/latent_diffusion/__init__.py b/src/refiners/foundationals/latent_diffusion/__init__.py
index ca07e19..b301aac 100644
--- a/src/refiners/foundationals/latent_diffusion/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/__init__.py
@@ -11,11 +11,13 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     SD1UNet,
     SD1ControlnetAdapter,
     SD1IPAdapter,
+    SD1T2IAdapter,
 )
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
     SDXLUNet,
     DoubleTextEncoder,
     SDXLIPAdapter,
+    SDXLT2IAdapter,
 )
 
 
@@ -25,9 +27,11 @@ __all__ = [
     "SD1UNet",
     "SD1ControlnetAdapter",
     "SD1IPAdapter",
+    "SD1T2IAdapter",
     "SDXLUNet",
     "DoubleTextEncoder",
     "SDXLIPAdapter",
+    "SDXLT2IAdapter",
     "DPMSolver",
     "Scheduler",
     "CLIPTextEncoderL",
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
index b092398..8b94b69 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py
@@ -5,6 +5,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
 )
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.t2i_adapter import SD1T2IAdapter
 
 __all__ = [
     "StableDiffusion_1",
@@ -12,4 +13,5 @@ __all__ = [
     "SD1UNet",
     "SD1ControlnetAdapter",
     "SD1IPAdapter",
+    "SD1T2IAdapter",
 ]
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py
new file mode 100644
index 0000000..d263afc
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py
@@ -0,0 +1,46 @@
+from typing import cast, Iterable
+
+from torch import Tensor
+
+from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, ResidualAccumulator
+import refiners.fluxion.layers as fl
+
+
+class SD1T2IAdapter(T2IAdapter[SD1UNet]):
+    def __init__(
+        self,
+        target: SD1UNet,
+        name: str,
+        condition_encoder: ConditionEncoder | None = None,
+        weights: dict[str, Tensor] | None = None,
+    ) -> None:
+        self.residual_indices = (2, 5, 8, 11)
+        super().__init__(
+            target=target,
+            name=name,
+            condition_encoder=condition_encoder or ConditionEncoder(device=target.device, dtype=target.dtype),
+            weights=weights,
+        )
+
+    def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
+        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
+            if n not in self.residual_indices:
+                continue
+            for t2i_layer in block.layers(layer_type=T2IFeatures):
+                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
+            block.insert_before_type(
+                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
+            )
+        return super().inject(parent)
+
+    def eject(self: "SD1T2IAdapter") -> None:
+        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
+            if n not in self.residual_indices:
+                continue
+            t2i_layers = [
+                t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
+            ]
+            assert len(t2i_layers) == 1
+            block.remove(t2i_layers.pop())
+        super().eject()
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
index f2b20aa..775a0e1 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py
@@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
 
 
 __all__ = [
@@ -9,4 +10,5 @@ __all__ = [
     "DoubleTextEncoder",
     "StableDiffusion_XL",
     "SDXLIPAdapter",
+    "SDXLT2IAdapter",
 ]
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py
new file mode 100644
index 0000000..65be7e4
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py
@@ -0,0 +1,57 @@
+from typing import cast, Iterable
+
+from torch import Tensor
+
+from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
+import refiners.fluxion.layers as fl
+
+
+class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
+    def __init__(
+        self,
+        target: SDXLUNet,
+        name: str,
+        condition_encoder: ConditionEncoderXL | None = None,
+        weights: dict[str, Tensor] | None = None,
+    ) -> None:
+        self.residual_indices = (3, 5, 8)  # the UNet's middle block is handled separately (see `inject` and `eject`)
+        super().__init__(
+            target=target,
+            name=name,
+            condition_encoder=condition_encoder or ConditionEncoderXL(device=target.device, dtype=target.dtype),
+            weights=weights,
+        )
+
+    def inject(self: "SDXLT2IAdapter", parent: fl.Chain | None = None) -> "SDXLT2IAdapter":
+        def sanity_check_t2i(block: fl.Module) -> None:
+            for t2i_layer in block.layers(layer_type=T2IFeatures):
+                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
+
+        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
+            if n not in self.residual_indices:
+                continue
+            sanity_check_t2i(block)
+            block.insert_before_type(
+                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
+            )
+        sanity_check_t2i(self.target.MiddleBlock)
+        # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
+        self.target.MiddleBlock.append(T2IFeatures(name=self.name, index=-1))
+        return super().inject(parent)
+
+    def eject(self: "SDXLT2IAdapter") -> None:
+        def eject_t2i(block: fl.Module) -> None:
+            t2i_layers = [
+                t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
+            ]
+            assert len(t2i_layers) == 1
+            block.remove(t2i_layers.pop())
+
+        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
+            if n not in self.residual_indices:
+                continue
+            eject_t2i(block)
+        eject_t2i(self.target.MiddleBlock)
+        super().eject()
diff --git a/src/refiners/foundationals/latent_diffusion/t2i_adapter.py b/src/refiners/foundationals/latent_diffusion/t2i_adapter.py
new file mode 100644
index 0000000..ab2ef8b
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/t2i_adapter.py
@@ -0,0 +1,216 @@
+from typing import Generic, TypeVar, Any, TYPE_CHECKING
+
+from torch import Tensor, device as Device, dtype as DType
+from torch.nn import AvgPool2d as _AvgPool2d
+
+from refiners.fluxion.adapters.adapter import Adapter
+from refiners.fluxion.context import Contexts
+from refiners.fluxion.layers.module import Module
+import refiners.fluxion.layers as fl
+
+if TYPE_CHECKING:
+    from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+    from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+
+T = TypeVar("T", bound="SD1UNet | SDXLUNet")
+TT2IAdapter = TypeVar("TT2IAdapter", bound="T2IAdapter[Any]")  # Self (see PEP 673)
+
+
+class Downsample2d(_AvgPool2d, Module):
+    def __init__(self, scale_factor: int) -> None:
+        _AvgPool2d.__init__(self, kernel_size=scale_factor, stride=scale_factor)
+
+
+class ResidualBlock(fl.Residual):
+    def __init__(
+        self,
+        channels: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            fl.Conv2d(
+                in_channels=channels, out_channels=channels, kernel_size=3, padding=1, device=device, dtype=dtype
+            ),
+            fl.ReLU(),
+            fl.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
+        )
+
+
+class ResidualBlocks(fl.Chain):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_residual_blocks: int = 2,
+        downsample: bool = False,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        preproc = Downsample2d(scale_factor=2) if downsample else fl.Identity()
+        shortcut = (
+            fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
+            if in_channels != out_channels
+            else fl.Identity()
+        )
+        super().__init__(
+            preproc,
+            shortcut,
+            fl.Chain(
+                ResidualBlock(channels=out_channels, device=device, dtype=dtype) for _ in range(num_residual_blocks)
+            ),
+        )
+
+
+class StatefulResidualBlocks(fl.Chain):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_residual_blocks: int = 2,
+        downsample: bool = False,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        super().__init__(
+            ResidualBlocks(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                num_residual_blocks=num_residual_blocks,
+                downsample=downsample,
+                device=device,
+                dtype=dtype,
+            ),
+            fl.SetContext(context="t2iadapter", key="features", callback=self.push),
+        )
+
+    def push(self, features: list[Tensor], x: Tensor) -> None:
+        features.append(x)
+
+
+class ConditionEncoder(fl.Chain):
+    def __init__(
+        self,
+        in_channels: int = 3,
+        channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
+        num_residual_blocks: int = 2,
+        downscale_factor: int = 8,
+        scale: float = 1.0,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.scale = scale
+        super().__init__(
+            fl.PixelUnshuffle(downscale_factor=downscale_factor),
+            fl.Conv2d(
+                in_channels=in_channels * downscale_factor**2,
+                out_channels=channels[0],
+                kernel_size=3,
+                padding=1,
+                device=device,
+                dtype=dtype,
+            ),
+            StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
+            *(
+                StatefulResidualBlocks(
+                    channels[i - 1], channels[i], num_residual_blocks, downsample=True, device=device, dtype=dtype
+                )
+                for i in range(1, len(channels))
+            ),
+            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
+        )
+
+    def scale_outputs(self, features: list[Tensor]) -> tuple[Tensor, ...]:
+        assert len(features) == 4
+        return tuple([x * self.scale for x in features])
+
+    def init_context(self) -> Contexts:
+        return {"t2iadapter": {"features": []}}
+
+
+class ConditionEncoderXL(ConditionEncoder, fl.Chain):
+    def __init__(
+        self,
+        in_channels: int = 3,
+        channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
+        num_residual_blocks: int = 2,
+        downscale_factor: int = 16,
+        scale: float = 1.0,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        self.scale = scale
+        fl.Chain.__init__(
+            self,
+            fl.PixelUnshuffle(downscale_factor=downscale_factor),
+            fl.Conv2d(
+                in_channels=in_channels * downscale_factor**2,
+                out_channels=channels[0],
+                kernel_size=3,
+                padding=1,
+                device=device,
+                dtype=dtype,
+            ),
+            StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
+            StatefulResidualBlocks(channels[0], channels[1], num_residual_blocks, device=device, dtype=dtype),
+            StatefulResidualBlocks(
+                channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
+            ),
+            StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
+            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
+        )
+
+
+class T2IFeatures(fl.Residual):
+    def __init__(self, name: str, index: int) -> None:
+        self.name = name
+        self.index = index
+        super().__init__(
+            fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
+                func=lambda features: features[self.index]
+            )
+        )
+
+
+class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
+    _condition_encoder: list[ConditionEncoder]  # prevent PyTorch module registration
+
+    def __init__(
+        self,
+        target: T,
+        name: str,
+        condition_encoder: ConditionEncoder,
+        weights: dict[str, Tensor] | None = None,
+    ) -> None:
+        self.name = name
+        if weights is not None:
+            condition_encoder.load_state_dict(weights)
+        self._condition_encoder = [condition_encoder]
+
+        with self.setup_adapter(target):
+            super().__init__(target)
+
+    def inject(self: TT2IAdapter, parent: fl.Chain | None = None) -> TT2IAdapter:
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        super().eject()
+
+    @property
+    def condition_encoder(self) -> ConditionEncoder:
+        return self._condition_encoder[0]
+
+    def compute_condition_features(self, condition: Tensor) -> tuple[Tensor, ...]:
+        return self.condition_encoder(condition)
+
+    def set_condition_features(self, features: tuple[Tensor, ...]) -> None:
+        self.set_context("t2iadapter", {f"condition_features_{self.name}": features})
+
+    def set_scale(self, scale: float) -> None:
+        self.condition_encoder.scale = scale
+
+    def init_context(self) -> Contexts:
+        return {"t2iadapter": {f"condition_features_{self.name}": None}}
+
+    def structural_copy(self: "TT2IAdapter") -> "TT2IAdapter":
+        raise RuntimeError("T2I-Adapter cannot be copied, eject it first.")
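
Reviewer notes (not part of the patch):

Usage sketch: the snippet below strings the new SD1 API together end to end. The `SD1UNet(in_channels=4)` construction, the safetensors path, and the random condition tensor are placeholder assumptions for illustration; only the `SD1T2IAdapter` calls mirror the code in this diff. Note that `set_scale` should run before `compute_condition_features`, since the scale is applied when the encoder's outputs are collected.

    import torch

    from refiners.fluxion.utils import load_from_safetensors
    from refiners.foundationals.latent_diffusion import SD1UNet, SD1T2IAdapter

    unet = SD1UNet(in_channels=4)  # assumed construction; load your SD1 UNet weights as usual
    adapter = SD1T2IAdapter(
        target=unet,
        name="depth",
        weights=load_from_safetensors("t2i_depth.safetensors"),  # hypothetical checkpoint path
    ).inject()

    condition = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed condition image (e.g. a depth map)
    adapter.set_scale(0.8)  # applied when features are computed, so set it first
    adapter.set_condition_features(adapter.compute_condition_features(condition))

    # ... run the usual denoising loop with `unet`: the injected T2IFeatures layers
    # read the stored features from the "t2iadapter" context and add them to the
    # residuals of down blocks (2, 5, 8, 11) ...

    adapter.eject()  # removes the T2IFeatures layers and restores the vanilla UNet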
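As a sanity check on the `ConditionEncoder` wiring, a small shape probe (again only a sketch): with the defaults (`downscale_factor=8`, `channels=(320, 640, 1280, 1280)`), a 512x512 condition should produce four feature maps at 64, 32, 16 and 8 pixels, which is what the four `T2IFeatures` layers consume on SD1.

    import torch

    from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder

    encoder = ConditionEncoder()  # defaults: in_channels=3, downscale_factor=8
    # PixelUnshuffle(8) maps 1x3x512x512 to 1x192x64x64, then the residual stages run
    features = encoder(torch.randn(1, 3, 512, 512))

    for f in features:
        print(tuple(f.shape))
    # Expected, given the defaults:
    # (1, 320, 64, 64)
    # (1, 640, 32, 32)
    # (1, 1280, 16, 16)
    # (1, 1280, 8, 8)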