add T2I-Adapter to foundationals/latent_diffusion

Cédric Deltheil 2023-09-24 15:44:20 +02:00 committed by Cédric Deltheil
parent d72e1d3478
commit 14864857b1
6 changed files with 327 additions and 0 deletions
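For orientation before the diff: a minimal usage sketch of the API this commit adds. This is illustrative and not part of the commit; the checkpoint path and the random condition tensor are placeholders, and `load_from_safetensors` is refiners' existing safetensors loading helper. Note that the scale is applied when the condition is encoded, so it is set before computing features.

import torch

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import SD1UNet, SD1T2IAdapter

unet = SD1UNet(in_channels=4)
adapter = SD1T2IAdapter(
    target=unet,
    name="depth",
    weights=load_from_safetensors("t2i_depth.safetensors"),  # placeholder path
).inject()

# Scale is baked in at encoding time (see `scale_outputs`), so set it first.
adapter.set_scale(0.8)
condition = torch.randn(1, 3, 512, 512)  # stand-in for a real depth map
adapter.set_condition_features(adapter.compute_condition_features(condition))
# ...then run the UNet / diffusion loop as usual; the injected T2IFeatures
# layers add the encoded features to the matching down-block residuals.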

refiners/foundationals/latent_diffusion/__init__.py

@@ -11,11 +11,13 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
SD1UNet,
SD1ControlnetAdapter,
SD1IPAdapter,
SD1T2IAdapter,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
SDXLUNet,
DoubleTextEncoder,
SDXLIPAdapter,
SDXLT2IAdapter,
)
@@ -25,9 +27,11 @@ __all__ = [
"SD1UNet",
"SD1ControlnetAdapter",
"SD1IPAdapter",
"SD1T2IAdapter",
"SDXLUNet",
"DoubleTextEncoder",
"SDXLIPAdapter",
"SDXLT2IAdapter",
"DPMSolver",
"Scheduler",
"CLIPTextEncoderL",

refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py

@@ -5,6 +5,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.t2i_adapter import SD1T2IAdapter

__all__ = [
"StableDiffusion_1",
@@ -12,4 +13,5 @@ __all__ = [
"SD1UNet",
"SD1ControlnetAdapter",
"SD1IPAdapter",
"SD1T2IAdapter",
]

refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py

@@ -0,0 +1,46 @@
from typing import cast, Iterable

from torch import Tensor

from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, ResidualAccumulator
import refiners.fluxion.layers as fl


class SD1T2IAdapter(T2IAdapter[SD1UNet]):
def __init__(
self,
target: SD1UNet,
name: str,
condition_encoder: ConditionEncoder | None = None,
weights: dict[str, Tensor] | None = None,
) -> None:
self.residual_indices = (2, 5, 8, 11)
super().__init__(
target=target,
name=name,
condition_encoder=condition_encoder or ConditionEncoder(device=target.device, dtype=target.dtype),
weights=weights,
        )

    def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
if n not in self.residual_indices:
continue
for t2i_layer in block.layers(layer_type=T2IFeatures):
assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
block.insert_before_type(
ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
)
        return super().inject(parent)

    def eject(self: "SD1T2IAdapter") -> None:
for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
if n not in self.residual_indices:
continue
t2i_layers = [
t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
]
assert len(t2i_layers) == 1
block.remove(t2i_layers.pop())
super().eject()
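A quick sketch (not in the diff) of the injection round trip, using the default randomly initialized condition encoder:

from refiners.foundationals.latent_diffusion import SD1UNet, SD1T2IAdapter

unet = SD1UNet(in_channels=4)
adapter = SD1T2IAdapter(target=unet, name="depth").inject()
# DownBlocks 2, 5, 8 and 11 now each contain a T2IFeatures layer named
# "depth", inserted just before their ResidualAccumulator; injecting a second
# adapter under the same name would trip the assertion above.
adapter.eject()  # removes exactly those four layers, restoring the UNet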

refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py

@@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter

__all__ = [
@@ -9,4 +10,5 @@ __all__ = [
"DoubleTextEncoder",
"StableDiffusion_XL",
"SDXLIPAdapter",
"SDXLT2IAdapter",
]

refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py

@@ -0,0 +1,57 @@
from typing import cast, Iterable

from torch import Tensor

from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
import refiners.fluxion.layers as fl


class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
def __init__(
self,
target: SDXLUNet,
name: str,
condition_encoder: ConditionEncoderXL | None = None,
weights: dict[str, Tensor] | None = None,
) -> None:
self.residual_indices = (3, 5, 8) # the UNet's middle block is handled separately (see `inject` and `eject`)
super().__init__(
target=target,
name=name,
condition_encoder=condition_encoder or ConditionEncoderXL(device=target.device, dtype=target.dtype),
weights=weights,
        )

    def inject(self: "SDXLT2IAdapter", parent: fl.Chain | None = None) -> "SDXLT2IAdapter":
def sanity_check_t2i(block: fl.Module) -> None:
for t2i_layer in block.layers(layer_type=T2IFeatures):
                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"

        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
if n not in self.residual_indices:
continue
sanity_check_t2i(block)
block.insert_before_type(
ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
)
sanity_check_t2i(self.target.MiddleBlock)
# Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
self.target.MiddleBlock.append(T2IFeatures(name=self.name, index=-1))
        return super().inject(parent)

    def eject(self: "SDXLT2IAdapter") -> None:
def eject_t2i(block: fl.Module) -> None:
t2i_layers = [
t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
]
assert len(t2i_layers) == 1
            block.remove(t2i_layers.pop())

        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
if n not in self.residual_indices:
continue
eject_t2i(block)
eject_t2i(self.target.MiddleBlock)
super().eject()
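Same sketch for the SDXL variant (again not part of the diff): three down-block insertions, plus a T2IFeatures with index -1 (the last encoded feature map) appended to the MiddleBlock:

from refiners.foundationals.latent_diffusion import SDXLUNet, SDXLT2IAdapter

unet = SDXLUNet(in_channels=4)
adapter = SDXLT2IAdapter(target=unet, name="canny").inject()
# DownBlocks 3, 5 and 8 each get a T2IFeatures layer; the MiddleBlock gets
# one appended at the end, since it has no ResidualAccumulator to anchor on.
adapter.eject()  # removes all four, MiddleBlock included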

refiners/foundationals/latent_diffusion/t2i_adapter.py

@@ -0,0 +1,216 @@
from typing import Generic, TypeVar, Any, TYPE_CHECKING

from torch import Tensor, device as Device, dtype as DType
from torch.nn import AvgPool2d as _AvgPool2d

from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.module import Module
import refiners.fluxion.layers as fl

if TYPE_CHECKING:
    from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
    from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TT2IAdapter = TypeVar("TT2IAdapter", bound="T2IAdapter[Any]")  # Self (see PEP 673)


class Downsample2d(_AvgPool2d, Module):
def __init__(self, scale_factor: int) -> None:
        _AvgPool2d.__init__(self, kernel_size=scale_factor, stride=scale_factor)


class ResidualBlock(fl.Residual):
def __init__(
self,
channels: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Conv2d(
in_channels=channels, out_channels=channels, kernel_size=3, padding=1, device=device, dtype=dtype
),
fl.ReLU(),
fl.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
        )


class ResidualBlocks(fl.Chain):
def __init__(
self,
in_channels: int,
out_channels: int,
num_residual_blocks: int = 2,
downsample: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
preproc = Downsample2d(scale_factor=2) if downsample else fl.Identity()
shortcut = (
fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
if in_channels != out_channels
else fl.Identity()
)
super().__init__(
preproc,
shortcut,
fl.Chain(
ResidualBlock(channels=out_channels, device=device, dtype=dtype) for _ in range(num_residual_blocks)
),
        )


class StatefulResidualBlocks(fl.Chain):
def __init__(
self,
in_channels: int,
out_channels: int,
num_residual_blocks: int = 2,
downsample: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
super().__init__(
ResidualBlocks(
in_channels=in_channels,
out_channels=out_channels,
num_residual_blocks=num_residual_blocks,
downsample=downsample,
device=device,
dtype=dtype,
),
fl.SetContext(context="t2iadapter", key="features", callback=self.push),
        )

    def push(self, features: list[Tensor], x: Tensor) -> None:
        features.append(x)


class ConditionEncoder(fl.Chain):
def __init__(
self,
in_channels: int = 3,
channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
num_residual_blocks: int = 2,
downscale_factor: int = 8,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.scale = scale
super().__init__(
fl.PixelUnshuffle(downscale_factor=downscale_factor),
fl.Conv2d(
in_channels=in_channels * downscale_factor**2,
out_channels=channels[0],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
*(
StatefulResidualBlocks(
channels[i - 1], channels[i], num_residual_blocks, downsample=True, device=device, dtype=dtype
)
for i in range(1, len(channels))
),
fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
        )

    def scale_outputs(self, features: list[Tensor]) -> tuple[Tensor, ...]:
        assert len(features) == 4
        return tuple([x * self.scale for x in features])

    def init_context(self) -> Contexts:
        return {"t2iadapter": {"features": []}}


class ConditionEncoderXL(ConditionEncoder, fl.Chain):
def __init__(
self,
in_channels: int = 3,
channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
num_residual_blocks: int = 2,
downscale_factor: int = 16,
scale: float = 1.0,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.scale = scale
fl.Chain.__init__(
self,
fl.PixelUnshuffle(downscale_factor=downscale_factor),
fl.Conv2d(
in_channels=in_channels * downscale_factor**2,
out_channels=channels[0],
kernel_size=3,
padding=1,
device=device,
dtype=dtype,
),
StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
StatefulResidualBlocks(channels[0], channels[1], num_residual_blocks, device=device, dtype=dtype),
StatefulResidualBlocks(
channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
),
StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
        )


class T2IFeatures(fl.Residual):
def __init__(self, name: str, index: int) -> None:
self.name = name
self.index = index
super().__init__(
fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
func=lambda features: features[self.index]
)
        )


class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
    _condition_encoder: list[ConditionEncoder]  # prevent PyTorch module registration

    def __init__(
self,
target: T,
name: str,
condition_encoder: ConditionEncoder,
weights: dict[str, Tensor] | None = None,
) -> None:
self.name = name
if weights is not None:
condition_encoder.load_state_dict(weights)
self._condition_encoder = [condition_encoder]
with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: TT2IAdapter, parent: fl.Chain | None = None) -> TT2IAdapter:
        return super().inject(parent)

    def eject(self) -> None:
        super().eject()

    @property
    def condition_encoder(self) -> ConditionEncoder:
        return self._condition_encoder[0]

    def compute_condition_features(self, condition: Tensor) -> tuple[Tensor, ...]:
        return self.condition_encoder(condition)

    def set_condition_features(self, features: tuple[Tensor, ...]) -> None:
        self.set_context("t2iadapter", {f"condition_features_{self.name}": features})

    def set_scale(self, scale: float) -> None:
        self.condition_encoder.scale = scale

    def init_context(self) -> Contexts:
        return {"t2iadapter": {f"condition_features_{self.name}": None}}

    def structural_copy(self: "TT2IAdapter") -> "TT2IAdapter":
        raise RuntimeError("T2I-Adapter cannot be copied, eject it first.")
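Finally, a shape-check sketch for the condition encoder (SD1 settings, downscale_factor=8): a 512x512 RGB condition should produce four feature maps that line up with the UNet's down-block residuals. Illustrative only, with a randomly initialized encoder:

import torch

from refiners.foundationals.latent_diffusion.t2i_adapter import ConditionEncoder

encoder = ConditionEncoder()  # random init is fine for a shape check
features = encoder(torch.randn(1, 3, 512, 512))
assert [tuple(f.shape) for f in features] == [
    (1, 320, 64, 64),   # feeds DownBlock 2's residual
    (1, 640, 32, 32),   # DownBlock 5
    (1, 1280, 16, 16),  # DownBlock 8
    (1, 1280, 8, 8),    # DownBlock 11
]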