mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add T2I-Adapter to foundationals/latent_diffusion
This commit is contained in:
parent
d72e1d3478
commit
14864857b1
|
@ -11,11 +11,13 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
|
||||||
SD1UNet,
|
SD1UNet,
|
||||||
SD1ControlnetAdapter,
|
SD1ControlnetAdapter,
|
||||||
SD1IPAdapter,
|
SD1IPAdapter,
|
||||||
|
SD1T2IAdapter,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
|
||||||
SDXLUNet,
|
SDXLUNet,
|
||||||
DoubleTextEncoder,
|
DoubleTextEncoder,
|
||||||
SDXLIPAdapter,
|
SDXLIPAdapter,
|
||||||
|
SDXLT2IAdapter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,9 +27,11 @@ __all__ = [
|
||||||
"SD1UNet",
|
"SD1UNet",
|
||||||
"SD1ControlnetAdapter",
|
"SD1ControlnetAdapter",
|
||||||
"SD1IPAdapter",
|
"SD1IPAdapter",
|
||||||
|
"SD1T2IAdapter",
|
||||||
"SDXLUNet",
|
"SDXLUNet",
|
||||||
"DoubleTextEncoder",
|
"DoubleTextEncoder",
|
||||||
"SDXLIPAdapter",
|
"SDXLIPAdapter",
|
||||||
|
"SDXLT2IAdapter",
|
||||||
"DPMSolver",
|
"DPMSolver",
|
||||||
"Scheduler",
|
"Scheduler",
|
||||||
"CLIPTextEncoderL",
|
"CLIPTextEncoderL",
|
||||||
|
|
|
@ -5,6 +5,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.t2i_adapter import SD1T2IAdapter
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"StableDiffusion_1",
|
"StableDiffusion_1",
|
||||||
|
@ -12,4 +13,5 @@ __all__ = [
|
||||||
"SD1UNet",
|
"SD1UNet",
|
||||||
"SD1ControlnetAdapter",
|
"SD1ControlnetAdapter",
|
||||||
"SD1IPAdapter",
|
"SD1IPAdapter",
|
||||||
|
"SD1T2IAdapter",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
from typing import cast, Iterable
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, ResidualAccumulator
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
class SD1T2IAdapter(T2IAdapter[SD1UNet]):
    """T2I-Adapter specialization for the Stable Diffusion 1.x UNet.

    Condition features are injected into down blocks 2, 5, 8 and 11,
    immediately before each block's ``ResidualAccumulator``.
    """

    def __init__(
        self,
        target: SD1UNet,
        name: str,
        condition_encoder: ConditionEncoder | None = None,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        # Down blocks that receive T2I features (fixed by the SD 1.x architecture).
        self.residual_indices = (2, 5, 8, 11)
        super().__init__(
            target=target,
            name=name,
            condition_encoder=condition_encoder or ConditionEncoder(device=target.device, dtype=target.dtype),
            weights=weights,
        )

    def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
        """Insert a ``T2IFeatures`` layer into each selected down block, then inject the adapter."""
        for index, down_block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
            if index not in self.residual_indices:
                continue
            # Guard against injecting twice under the same adapter name.
            for t2i_layer in down_block.layers(layer_type=T2IFeatures):
                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
            down_block.insert_before_type(
                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(index))
            )
        return super().inject(parent)

    def eject(self: "SD1T2IAdapter") -> None:
        """Remove this adapter's ``T2IFeatures`` layers from the UNet and eject."""
        for index, down_block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
            if index not in self.residual_indices:
                continue
            matching_layers = [
                layer for layer in down_block.layers(layer_type=T2IFeatures) if layer.name == self.name
            ]
            assert len(matching_layers) == 1
            down_block.remove(matching_layers.pop())
        super().eject()
|
|
@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -9,4 +10,5 @@ __all__ = [
|
||||||
"DoubleTextEncoder",
|
"DoubleTextEncoder",
|
||||||
"StableDiffusion_XL",
|
"StableDiffusion_XL",
|
||||||
"SDXLIPAdapter",
|
"SDXLIPAdapter",
|
||||||
|
"SDXLT2IAdapter",
|
||||||
]
|
]
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
from typing import cast, Iterable
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
    """T2I-Adapter specialization for the SDXL UNet.

    Condition features go into down blocks 3, 5 and 8; the middle block
    receives the last feature map through a separate appended layer.
    """

    def __init__(
        self,
        target: SDXLUNet,
        name: str,
        condition_encoder: ConditionEncoderXL | None = None,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        self.residual_indices = (3, 5, 8)  # the UNet's middle block is handled separately (see `inject` and `eject`)
        super().__init__(
            target=target,
            name=name,
            condition_encoder=condition_encoder or ConditionEncoderXL(device=target.device, dtype=target.dtype),
            weights=weights,
        )

    def inject(self: "SDXLT2IAdapter", parent: fl.Chain | None = None) -> "SDXLT2IAdapter":
        """Insert ``T2IFeatures`` layers into the selected down blocks and the middle block."""

        def sanity_check_t2i(block: fl.Module) -> None:
            # Refuse to inject twice under the same adapter name.
            for t2i_layer in block.layers(layer_type=T2IFeatures):
                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"

        for index, down_block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
            if index not in self.residual_indices:
                continue
            sanity_check_t2i(down_block)
            down_block.insert_before_type(
                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(index))
            )
        sanity_check_t2i(self.target.MiddleBlock)
        # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
        self.target.MiddleBlock.append(T2IFeatures(name=self.name, index=-1))
        return super().inject(parent)

    def eject(self: "SDXLT2IAdapter") -> None:
        """Remove this adapter's ``T2IFeatures`` layers from the UNet and eject."""

        def eject_t2i(block: fl.Module) -> None:
            matching_layers = [
                layer for layer in block.layers(layer_type=T2IFeatures) if layer.name == self.name
            ]
            assert len(matching_layers) == 1
            block.remove(matching_layers.pop())

        for index, down_block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
            if index not in self.residual_indices:
                continue
            eject_t2i(down_block)
        eject_t2i(self.target.MiddleBlock)
        super().eject()
|
216
src/refiners/foundationals/latent_diffusion/t2i_adapter.py
Normal file
216
src/refiners/foundationals/latent_diffusion/t2i_adapter.py
Normal file
|
@ -0,0 +1,216 @@
|
||||||
|
from typing import Generic, TypeVar, Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
from torch.nn import AvgPool2d as _AvgPool2d
|
||||||
|
|
||||||
|
from refiners.fluxion.adapters.adapter import Adapter
|
||||||
|
from refiners.fluxion.context import Contexts
|
||||||
|
from refiners.fluxion.layers.module import Module
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
|
||||||
|
TT2IAdapter = TypeVar("TT2IAdapter", bound="T2IAdapter[Any]") # Self (see PEP 673)
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample2d(_AvgPool2d, Module):
    """Spatial downsampler: average pooling with kernel size equal to stride.

    Calls ``_AvgPool2d.__init__`` directly (not ``super()``) so the pooling
    layer is configured without going through the refiners Module chain.
    """

    def __init__(self, scale_factor: int) -> None:
        _AvgPool2d.__init__(self, kernel_size=scale_factor, stride=scale_factor)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlock(fl.Residual):
    """Residual unit: 3x3 conv -> ReLU -> 1x1 conv, added to the input.

    Channel count is preserved end to end.
    """

    def __init__(
        self,
        channels: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        conv_3x3 = fl.Conv2d(
            in_channels=channels, out_channels=channels, kernel_size=3, padding=1, device=device, dtype=dtype
        )
        conv_1x1 = fl.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype)
        super().__init__(conv_3x3, fl.ReLU(), conv_1x1)
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualBlocks(fl.Chain):
    """One encoder stage: optional 2x downsample, channel projection, then residual blocks.

    The 1x1 "shortcut" conv is applied in sequence (not as a skip) and only
    when the channel count changes; otherwise it is an identity.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_residual_blocks: int = 2,
        downsample: bool = False,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        if downsample:
            preproc: fl.Module = Downsample2d(scale_factor=2)
        else:
            preproc = fl.Identity()
        if in_channels != out_channels:
            shortcut: fl.Module = fl.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype
            )
        else:
            shortcut = fl.Identity()
        blocks = fl.Chain(
            ResidualBlock(channels=out_channels, device=device, dtype=dtype) for _ in range(num_residual_blocks)
        )
        super().__init__(preproc, shortcut, blocks)
|
||||||
|
|
||||||
|
|
||||||
|
class StatefulResidualBlocks(fl.Chain):
    """A ``ResidualBlocks`` stage that records its output into the t2iadapter context.

    After the stage runs, the resulting feature map is appended to the
    "features" list stored under the "t2iadapter" context.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_residual_blocks: int = 2,
        downsample: bool = False,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        stage = ResidualBlocks(
            in_channels=in_channels,
            out_channels=out_channels,
            num_residual_blocks=num_residual_blocks,
            downsample=downsample,
            device=device,
            dtype=dtype,
        )
        super().__init__(
            stage,
            fl.SetContext(context="t2iadapter", key="features", callback=self.push),
        )

    def push(self, features: list[Tensor], x: Tensor) -> None:
        """Accumulate this stage's output into the shared features list."""
        features.append(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionEncoder(fl.Chain):
    """Encodes a condition image into four multi-scale feature maps (SD 1.x layout).

    Pipeline: pixel-unshuffle by ``downscale_factor``, a 3x3 projection to
    ``channels[0]``, then a pyramid of ``StatefulResidualBlocks`` stages that
    each push their output into the "t2iadapter" context; the collected
    features are finally returned scaled by ``self.scale``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
        num_residual_blocks: int = 2,
        downscale_factor: int = 8,
        scale: float = 1.0,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # Multiplier applied to every output feature map (see set_scale on the adapter).
        self.scale = scale
        downsampling_stages = (
            StatefulResidualBlocks(
                channels[i - 1], channels[i], num_residual_blocks, downsample=True, device=device, dtype=dtype
            )
            for i in range(1, len(channels))
        )
        super().__init__(
            fl.PixelUnshuffle(downscale_factor=downscale_factor),
            fl.Conv2d(
                in_channels=in_channels * downscale_factor**2,
                out_channels=channels[0],
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
            StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
            *downsampling_stages,
            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
        )

    def scale_outputs(self, features: list[Tensor]) -> tuple[Tensor, ...]:
        """Return the collected features, each multiplied by ``self.scale``."""
        # One feature map per pyramid stage is expected.
        assert len(features) == 4
        return tuple(feature * self.scale for feature in features)

    def init_context(self) -> Contexts:
        """Start each forward pass with an empty features accumulator."""
        return {"t2iadapter": {"features": []}}
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionEncoderXL(ConditionEncoder, fl.Chain):
    """Condition encoder with the SDXL stage layout.

    Differences from ``ConditionEncoder``: default ``downscale_factor`` is 16,
    and only the third stage downsamples. Note this deliberately calls
    ``fl.Chain.__init__`` directly — NOT ``ConditionEncoder.__init__`` — so
    the parent's stage layout is bypassed entirely.
    """

    def __init__(
        self,
        in_channels: int = 3,
        channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
        num_residual_blocks: int = 2,
        downscale_factor: int = 16,
        scale: float = 1.0,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        # Multiplier applied to every output feature map (see set_scale on the adapter).
        self.scale = scale
        fl.Chain.__init__(
            self,
            fl.PixelUnshuffle(downscale_factor=downscale_factor),
            fl.Conv2d(
                in_channels=in_channels * downscale_factor**2,
                out_channels=channels[0],
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
            StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
            StatefulResidualBlocks(channels[0], channels[1], num_residual_blocks, device=device, dtype=dtype),
            StatefulResidualBlocks(
                channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
            ),
            StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
        )
|
||||||
|
|
||||||
|
|
||||||
|
class T2IFeatures(fl.Residual):
    """Residual layer adding one named adapter feature map to its input.

    Reads the tuple stored under ``condition_features_<name>`` in the
    "t2iadapter" context and selects the entry at ``index``.
    """

    def __init__(self, name: str, index: int) -> None:
        # name/index identify which adapter and which pyramid level this layer consumes.
        self.name = name
        self.index = index
        super().__init__(
            fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
                func=lambda features: features[self.index]
            )
        )
|
||||||
|
|
||||||
|
|
||||||
|
class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
    """Generic T2I-Adapter over a UNet target.

    Owns a ``ConditionEncoder`` (kept out of PyTorch's module registration)
    and publishes its feature maps through the "t2iadapter" context, where
    injected ``T2IFeatures`` layers pick them up.
    """

    # Wrapped in a list to prevent PyTorch module registration.
    _condition_encoder: list[ConditionEncoder]

    def __init__(
        self,
        target: T,
        name: str,
        condition_encoder: ConditionEncoder,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        self.name = name
        if weights is not None:
            condition_encoder.load_state_dict(weights)
        self._condition_encoder = [condition_encoder]

        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: TT2IAdapter, parent: fl.Chain | None = None) -> TT2IAdapter:
        return super().inject(parent)

    def eject(self) -> None:
        super().eject()

    @property
    def condition_encoder(self) -> ConditionEncoder:
        """The encoder, unwrapped from its registration-hiding list."""
        return self._condition_encoder[0]

    def compute_condition_features(self, condition: Tensor) -> tuple[Tensor, ...]:
        """Run the condition image through the encoder."""
        return self.condition_encoder(condition)

    def set_condition_features(self, features: tuple[Tensor, ...]) -> None:
        """Publish precomputed features for this adapter's T2IFeatures layers."""
        self.set_context("t2iadapter", {f"condition_features_{self.name}": features})

    def set_scale(self, scale: float) -> None:
        """Adjust the multiplier applied to encoder outputs."""
        self.condition_encoder.scale = scale

    def init_context(self) -> Contexts:
        return {"t2iadapter": {f"condition_features_{self.name}": None}}

    def structural_copy(self: "TT2IAdapter") -> "TT2IAdapter":
        raise RuntimeError("T2I-Adapter cannot be copied, eject it first.")
|
Loading…
Reference in a new issue