mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00

add T2I-Adapter to foundationals/latent_diffusion

This commit is contained in:
parent d72e1d3478
commit 14864857b1
src/refiners/foundationals/latent_diffusion/__init__.py

@@ -11,11 +11,13 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     SD1UNet,
     SD1ControlnetAdapter,
     SD1IPAdapter,
+    SD1T2IAdapter,
 )
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import (
     SDXLUNet,
     DoubleTextEncoder,
     SDXLIPAdapter,
+    SDXLT2IAdapter,
 )
@@ -25,9 +27,11 @@ __all__ = [
     "SD1UNet",
     "SD1ControlnetAdapter",
     "SD1IPAdapter",
+    "SD1T2IAdapter",
     "SDXLUNet",
     "DoubleTextEncoder",
     "SDXLIPAdapter",
+    "SDXLT2IAdapter",
     "DPMSolver",
     "Scheduler",
     "CLIPTextEncoderL",
src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py

@@ -5,6 +5,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
 )
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.t2i_adapter import SD1T2IAdapter

 __all__ = [
     "StableDiffusion_1",
@@ -12,4 +13,5 @@ __all__ = [
     "SD1UNet",
     "SD1ControlnetAdapter",
     "SD1IPAdapter",
+    "SD1T2IAdapter",
 ]
src/refiners/foundationals/latent_diffusion/stable_diffusion_1/t2i_adapter.py (new file)

@@ -0,0 +1,46 @@
from typing import cast, Iterable

from torch import Tensor

from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet, ResidualAccumulator
import refiners.fluxion.layers as fl


class SD1T2IAdapter(T2IAdapter[SD1UNet]):
    def __init__(
        self,
        target: SD1UNet,
        name: str,
        condition_encoder: ConditionEncoder | None = None,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        self.residual_indices = (2, 5, 8, 11)
        super().__init__(
            target=target,
            name=name,
            condition_encoder=condition_encoder or ConditionEncoder(device=target.device, dtype=target.dtype),
            weights=weights,
        )

    def inject(self: "SD1T2IAdapter", parent: fl.Chain | None = None) -> "SD1T2IAdapter":
        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
            if n not in self.residual_indices:
                continue
            for t2i_layer in block.layers(layer_type=T2IFeatures):
                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"
            block.insert_before_type(
                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
            )
        return super().inject(parent)

    def eject(self: "SD1T2IAdapter") -> None:
        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
            if n not in self.residual_indices:
                continue
            t2i_layers = [
                t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
            ]
            assert len(t2i_layers) == 1
            block.remove(t2i_layers.pop())
        super().eject()
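Note: a minimal usage sketch for the adapter above. This is not part of the commit; `unet` and `condition` are hypothetical names, the `SD1UNet(in_channels=4)` constructor call is an assumption based on how the UNet is built elsewhere in refiners, and real use would pass converted T2I-Adapter weights.

# Hypothetical usage sketch (not part of this commit).
import torch

from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.t2i_adapter import SD1T2IAdapter

unet = SD1UNet(in_channels=4)
adapter = SD1T2IAdapter(target=unet, name="sketch_t2i").inject()

condition = torch.randn(1, 3, 512, 512)  # preprocessed condition image (e.g. a sketch map)
adapter.set_condition_features(adapter.compute_condition_features(condition))

# ... run the usual SD1 denoising loop with `unet` ...

adapter.eject()  # removes the T2IFeatures layers and restores the plain UNet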
src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/__init__.py

@@ -2,6 +2,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.image_prompt import SDXLIPAdapter
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter

 __all__ = [
@@ -9,4 +10,5 @@ __all__ = [
     "DoubleTextEncoder",
     "StableDiffusion_XL",
     "SDXLIPAdapter",
+    "SDXLT2IAdapter",
 ]
src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/t2i_adapter.py (new file)

@@ -0,0 +1,57 @@
from typing import cast, Iterable

from torch import Tensor

from refiners.foundationals.latent_diffusion.t2i_adapter import T2IAdapter, T2IFeatures, ConditionEncoderXL
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import ResidualAccumulator
import refiners.fluxion.layers as fl


class SDXLT2IAdapter(T2IAdapter[SDXLUNet]):
    def __init__(
        self,
        target: SDXLUNet,
        name: str,
        condition_encoder: ConditionEncoderXL | None = None,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        self.residual_indices = (3, 5, 8)  # the UNet's middle block is handled separately (see `inject` and `eject`)
        super().__init__(
            target=target,
            name=name,
            condition_encoder=condition_encoder or ConditionEncoderXL(device=target.device, dtype=target.dtype),
            weights=weights,
        )

    def inject(self: "SDXLT2IAdapter", parent: fl.Chain | None = None) -> "SDXLT2IAdapter":
        def sanity_check_t2i(block: fl.Module) -> None:
            for t2i_layer in block.layers(layer_type=T2IFeatures):
                assert t2i_layer.name != self.name, f"T2I-Adapter named {self.name} is already injected"

        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
            if n not in self.residual_indices:
                continue
            sanity_check_t2i(block)
            block.insert_before_type(
                ResidualAccumulator, T2IFeatures(name=self.name, index=self.residual_indices.index(n))
            )
        sanity_check_t2i(self.target.MiddleBlock)
        # Special case: the MiddleBlock has no ResidualAccumulator (this is done via a subsequent layer) so just append
        self.target.MiddleBlock.append(T2IFeatures(name=self.name, index=-1))
        return super().inject(parent)

    def eject(self: "SDXLT2IAdapter") -> None:
        def eject_t2i(block: fl.Module) -> None:
            t2i_layers = [
                t2i_layer for t2i_layer in block.layers(layer_type=T2IFeatures) if t2i_layer.name == self.name
            ]
            assert len(t2i_layers) == 1
            block.remove(t2i_layers.pop())

        for n, block in enumerate(cast(Iterable[fl.Chain], self.target.DownBlocks)):
            if n not in self.residual_indices:
                continue
            eject_t2i(block)
        eject_t2i(self.target.MiddleBlock)
        super().eject()
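Note: the SDXL flow is the same as SD1; a hedged sketch follows (hypothetical names, and the `SDXLUNet(in_channels=4)` call is an assumption about the UNet constructor). It also shows the conditioning scale, which must be set before computing features since scaling happens in the encoder's `scale_outputs`.

# Hypothetical usage sketch (not part of this commit).
import torch

from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.t2i_adapter import SDXLT2IAdapter

sdxl_unet = SDXLUNet(in_channels=4)
adapter = SDXLT2IAdapter(target=sdxl_unet, name="canny_t2i").inject()
adapter.set_scale(0.8)  # attenuate all condition features uniformly

condition = torch.randn(1, 3, 1024, 1024)
features = adapter.compute_condition_features(condition)
adapter.set_condition_features(features)  # index -1, i.e. features[-1], feeds the MiddleBlock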
src/refiners/foundationals/latent_diffusion/t2i_adapter.py (new file, 216 lines)
@@ -0,0 +1,216 @@
from typing import Generic, TypeVar, Any, TYPE_CHECKING

from torch import Tensor, device as Device, dtype as DType
from torch.nn import AvgPool2d as _AvgPool2d

from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.module import Module
import refiners.fluxion.layers as fl

if TYPE_CHECKING:
    from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
    from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet

T = TypeVar("T", bound="SD1UNet | SDXLUNet")
TT2IAdapter = TypeVar("TT2IAdapter", bound="T2IAdapter[Any]")  # Self (see PEP 673)


class Downsample2d(_AvgPool2d, Module):
    def __init__(self, scale_factor: int) -> None:
        _AvgPool2d.__init__(self, kernel_size=scale_factor, stride=scale_factor)


class ResidualBlock(fl.Residual):
    def __init__(
        self,
        channels: int,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            fl.Conv2d(
                in_channels=channels, out_channels=channels, kernel_size=3, padding=1, device=device, dtype=dtype
            ),
            fl.ReLU(),
            fl.Conv2d(in_channels=channels, out_channels=channels, kernel_size=1, device=device, dtype=dtype),
        )


class ResidualBlocks(fl.Chain):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_residual_blocks: int = 2,
        downsample: bool = False,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        preproc = Downsample2d(scale_factor=2) if downsample else fl.Identity()
        shortcut = (
            fl.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, device=device, dtype=dtype)
            if in_channels != out_channels
            else fl.Identity()
        )
        super().__init__(
            preproc,
            shortcut,
            fl.Chain(
                ResidualBlock(channels=out_channels, device=device, dtype=dtype) for _ in range(num_residual_blocks)
            ),
        )
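Note: one subtlety worth flagging here. Since fluxion's `fl.Chain` composes its children sequentially, the `shortcut` above is applied in sequence as a 1x1 channel projection, not as a skip path; only the inner `ResidualBlock`s add a residual. A hedged plain-PyTorch equivalent of the dataflow (`ResidualBlocksSketch` is a hypothetical name, not part of the commit):

# Hypothetical plain-PyTorch equivalent of ResidualBlocks (sketch, not the commit's code).
import torch
from torch import Tensor, nn

class ResidualBlocksSketch(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, num_blocks: int = 2, downsample: bool = False) -> None:
        super().__init__()
        self.preproc = nn.AvgPool2d(2, 2) if downsample else nn.Identity()
        self.project = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.blocks = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_channels, out_channels, 1),
            )
            for _ in range(num_blocks)
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.project(self.preproc(x))  # sequential 1x1 projection, not a skip connection
        for block in self.blocks:
            x = x + block(x)  # only the inner blocks are residual
        return x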
class StatefulResidualBlocks(fl.Chain):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_residual_blocks: int = 2,
        downsample: bool = False,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        super().__init__(
            ResidualBlocks(
                in_channels=in_channels,
                out_channels=out_channels,
                num_residual_blocks=num_residual_blocks,
                downsample=downsample,
                device=device,
                dtype=dtype,
            ),
            fl.SetContext(context="t2iadapter", key="features", callback=self.push),
        )

    def push(self, features: list[Tensor], x: Tensor) -> None:
        features.append(x)


class ConditionEncoder(fl.Chain):
    def __init__(
        self,
        in_channels: int = 3,
        channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
        num_residual_blocks: int = 2,
        downscale_factor: int = 8,
        scale: float = 1.0,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.scale = scale
        super().__init__(
            fl.PixelUnshuffle(downscale_factor=downscale_factor),
            fl.Conv2d(
                in_channels=in_channels * downscale_factor**2,
                out_channels=channels[0],
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
            StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
            *(
                StatefulResidualBlocks(
                    channels[i - 1], channels[i], num_residual_blocks, downsample=True, device=device, dtype=dtype
                )
                for i in range(1, len(channels))
            ),
            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
        )

    def scale_outputs(self, features: list[Tensor]) -> tuple[Tensor, ...]:
        assert len(features) == 4
        return tuple([x * self.scale for x in features])

    def init_context(self) -> Contexts:
        return {"t2iadapter": {"features": []}}
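Note: each `StatefulResidualBlocks` pushes its output into the "t2iadapter" context, so the encoder returns a four-level feature pyramid matching the SD1 UNet's down blocks. A hedged shape check (hypothetical, random input, follows from PixelUnshuffle(8) plus three downsampling stages):

# Hypothetical shape check (not part of this commit).
import torch

encoder = ConditionEncoder()
features = encoder(torch.randn(1, 3, 512, 512))
assert [f.shape[1] for f in features] == [320, 640, 1280, 1280]  # channels per level
assert [f.shape[-1] for f in features] == [64, 32, 16, 8]  # spatial sizes per level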
class ConditionEncoderXL(ConditionEncoder, fl.Chain):
    def __init__(
        self,
        in_channels: int = 3,
        channels: tuple[int, int, int, int] = (320, 640, 1280, 1280),
        num_residual_blocks: int = 2,
        downscale_factor: int = 16,
        scale: float = 1.0,
        device: Device | str | None = None,
        dtype: DType | None = None,
    ) -> None:
        self.scale = scale
        fl.Chain.__init__(
            self,
            fl.PixelUnshuffle(downscale_factor=downscale_factor),
            fl.Conv2d(
                in_channels=in_channels * downscale_factor**2,
                out_channels=channels[0],
                kernel_size=3,
                padding=1,
                device=device,
                dtype=dtype,
            ),
            StatefulResidualBlocks(channels[0], channels[0], num_residual_blocks, device=device, dtype=dtype),
            StatefulResidualBlocks(channels[0], channels[1], num_residual_blocks, device=device, dtype=dtype),
            StatefulResidualBlocks(
                channels[1], channels[2], num_residual_blocks, downsample=True, device=device, dtype=dtype
            ),
            StatefulResidualBlocks(channels[2], channels[3], num_residual_blocks, device=device, dtype=dtype),
            fl.UseContext(context="t2iadapter", key="features").compose(func=self.scale_outputs),
        )
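Note: `ConditionEncoderXL` differs from the SD1 encoder only in its stage layout: `downscale_factor=16` and a single in-pyramid downsample instead of three, to line up with the SDXL UNet's resolutions. The corresponding hedged shape check for a 1024x1024 condition:

# Hypothetical shape check (not part of this commit).
import torch

encoder_xl = ConditionEncoderXL()
features = encoder_xl(torch.randn(1, 3, 1024, 1024))
assert [f.shape[1] for f in features] == [320, 640, 1280, 1280]
assert [f.shape[-1] for f in features] == [64, 64, 32, 32]  # only one downsample inside the pyramid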
class T2IFeatures(fl.Residual):
    def __init__(self, name: str, index: int) -> None:
        self.name = name
        self.index = index
        super().__init__(
            fl.UseContext(context="t2iadapter", key=f"condition_features_{self.name}").compose(
                func=lambda features: features[self.index]
            )
        )
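Note: `T2IFeatures` is a named residual add: it reads the tuple stored under `condition_features_{name}` in the "t2iadapter" context and adds the entry at `index` to its input. A hedged standalone equivalent (hypothetical helper, for illustration only):

# Roughly what T2IFeatures computes (hypothetical standalone version).
from torch import Tensor

def t2i_features(x: Tensor, condition_features: tuple[Tensor, ...], index: int) -> Tensor:
    return x + condition_features[index]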
class T2IAdapter(Generic[T], fl.Chain, Adapter[T]):
    _condition_encoder: list[ConditionEncoder]  # prevent PyTorch module registration

    def __init__(
        self,
        target: T,
        name: str,
        condition_encoder: ConditionEncoder,
        weights: dict[str, Tensor] | None = None,
    ) -> None:
        self.name = name
        if weights is not None:
            condition_encoder.load_state_dict(weights)
        self._condition_encoder = [condition_encoder]

        with self.setup_adapter(target):
            super().__init__(target)

    def inject(self: TT2IAdapter, parent: fl.Chain | None = None) -> TT2IAdapter:
        return super().inject(parent)

    def eject(self) -> None:
        super().eject()

    @property
    def condition_encoder(self) -> ConditionEncoder:
        return self._condition_encoder[0]

    def compute_condition_features(self, condition: Tensor) -> tuple[Tensor, ...]:
        return self.condition_encoder(condition)

    def set_condition_features(self, features: tuple[Tensor, ...]) -> None:
        self.set_context("t2iadapter", {f"condition_features_{self.name}": features})

    def set_scale(self, scale: float) -> None:
        self.condition_encoder.scale = scale

    def init_context(self) -> Contexts:
        return {"t2iadapter": {f"condition_features_{self.name}": None}}

    def structural_copy(self: "TT2IAdapter") -> "TT2IAdapter":
        raise RuntimeError("T2I-Adapter cannot be copied, eject it first.")