mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
implement StyleAlignedAdapter
This commit is contained in:
parent
432e32f94f
commit
efa3988638
329
src/refiners/foundationals/latent_diffusion/style_aligned.py
Normal file
329
src/refiners/foundationals/latent_diffusion/style_aligned.py
Normal file
|
@ -0,0 +1,329 @@
|
|||
from functools import cached_property
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
from torch import Tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||
|
||||
T = TypeVar("T", bound="SD1UNet | SDXLUNet")
|
||||
|
||||
|
||||
class ExtractReferenceFeatures(fl.Module):
|
||||
"""Extract the reference features from the input features.
|
||||
|
||||
Note:
|
||||
This layer expects the input features to be a concatenation of conditional and unconditional features,
|
||||
as done when using Classifier-free guidance (CFG).
|
||||
|
||||
The reference features are the first features of the conditional and unconditional input features.
|
||||
They are extracted, and repeated to match the batch size of the input features.
|
||||
|
||||
Receives:
|
||||
features (Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]): The input features.
|
||||
|
||||
Returns:
|
||||
reference (Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]): The reference features.
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
features: Float[Tensor, "cfg_batch_size sequence_length embedding_dim"],
|
||||
) -> Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]:
|
||||
cfg_batch_size = features.shape[0]
|
||||
batch_size = cfg_batch_size // 2
|
||||
|
||||
# split the cfg
|
||||
features_cond, features_uncond = torch.chunk(features, 2, dim=0)
|
||||
# -> 2 x (batch_size, sequence_length, embedding_dim)
|
||||
|
||||
# extract the reference features
|
||||
features_ref = torch.stack(
|
||||
(
|
||||
features_cond[0], # (sequence_length, embedding_dim)
|
||||
features_uncond[0], # (sequence_length, embedding_dim)
|
||||
),
|
||||
) # -> (2, sequence_length, embedding_dim)
|
||||
|
||||
# repeat the reference features to match the batch size
|
||||
features_ref = features_ref.repeat_interleave(batch_size, dim=0)
|
||||
# -> (cfg_batch_size, sequence_length, embedding_dim)
|
||||
|
||||
return features_ref
|
||||
|
||||
|
||||
class AdaIN(fl.Module):
|
||||
"""Apply Adaptive Instance Normalization (AdaIN) to the target features.
|
||||
|
||||
See [[arXiv:1703.06868] Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization](https://arxiv.org/abs/1703.06868) for more details.
|
||||
|
||||
Receives:
|
||||
reference (Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]): The reference features.
|
||||
targets (Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]): The target features.
|
||||
|
||||
Returns:
|
||||
reference (Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]): The reference features (unchanged).
|
||||
targets (Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]): The target features, renormalized.
|
||||
"""
|
||||
|
||||
def __init__(self, epsilon: float = 1e-8) -> None:
|
||||
"""Initialize the AdaIN module.
|
||||
|
||||
Args:
|
||||
epsilon: A small value to avoid division by zero.
|
||||
"""
|
||||
super().__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(
|
||||
self,
|
||||
targets: Float[Tensor, "cfg_batch_size sequence_length embedding_dim"],
|
||||
reference: Float[Tensor, "cfg_batch_size sequence_length embedding_dim"],
|
||||
) -> tuple[
|
||||
Float[Tensor, "cfg_batch_size sequence_length embedding_dim"], # targets (renormalized)
|
||||
Float[Tensor, "cfg_batch_size sequence_length embedding_dim"], # reference (unchanged)
|
||||
]:
|
||||
targets_mean = torch.mean(targets, dim=-2, keepdim=True)
|
||||
targets_std = torch.std(targets, dim=-2, keepdim=True)
|
||||
targets_normalized = (targets - targets_mean) / (targets_std + self.epsilon)
|
||||
|
||||
reference_mean = torch.mean(reference, dim=-2, keepdim=True)
|
||||
reference_std = torch.std(reference, dim=-2, keepdim=True)
|
||||
targets_renormalized = targets_normalized * reference_std + reference_mean
|
||||
|
||||
return (
|
||||
targets_renormalized,
|
||||
reference,
|
||||
)
|
||||
|
||||
|
||||
class ScaleReferenceFeatures(fl.Module):
|
||||
"""Scale the reference features.
|
||||
|
||||
Note:
|
||||
This layer expects the input features to be a concatenation of conditional and unconditional features,
|
||||
as done when using Classifier-free guidance (CFG).
|
||||
|
||||
This layer scales the reference features which will later be used (in the attention dot product) with the target features.
|
||||
|
||||
Receives:
|
||||
features (Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]): The input reference features.
|
||||
|
||||
Returns:
|
||||
features (Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]): The rescaled reference features.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale: float = 1.0,
|
||||
) -> None:
|
||||
"""Initialize the ScaleReferenceFeatures module.
|
||||
|
||||
Args:
|
||||
scale: The scaling factor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
|
||||
def forward(
|
||||
self,
|
||||
features: Float[Tensor, "cfg_batch_size sequence_length embedding_dim"],
|
||||
) -> Float[Tensor, "cfg_batch_size sequence_length embedding_dim"]:
|
||||
cfg_batch_size = features.shape[0]
|
||||
batch_size = cfg_batch_size // 2
|
||||
|
||||
# clone the features
|
||||
# needed because all the following operations are in-place
|
||||
features = features.clone()
|
||||
|
||||
# "stack" the cfg
|
||||
features_cfg_stack = features.reshape(2, batch_size, *features.shape[1:])
|
||||
|
||||
# scale the reference features which will later be used (in the attention dot product) with the target features
|
||||
features_cfg_stack[:, 1:] *= self.scale
|
||||
|
||||
# "unstack" the cfg
|
||||
features = features_cfg_stack.reshape(features.shape)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
class StyleAligned(fl.Chain):
|
||||
"""StyleAligned module.
|
||||
|
||||
This layer encapsulates the logic of the StyleAligned method,
|
||||
as described in [[arXiv:2312.02133] Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133).
|
||||
|
||||
See also <https://blog.finegrain.ai/posts/implementing-style-aligned/>.
|
||||
|
||||
Receives:
|
||||
features (Float[Tensor, "cfg_batch_size sequence_length_in embedding_dim"]): The input features.
|
||||
|
||||
Returns:
|
||||
shared_features (Float[Tensor, "cfg_batch_size sequence_length_out embedding_dim"]): The transformed features.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adain: bool,
|
||||
concatenate: bool,
|
||||
scale: float = 1.0,
|
||||
) -> None:
|
||||
"""Initialize the StyleAligned module.
|
||||
|
||||
Args:
|
||||
adain: Whether to apply Adaptive Instance Normalization to the target features.
|
||||
scale: The scaling factor for the reference features.
|
||||
concatenate: Whether to concatenate the reference and target features.
|
||||
"""
|
||||
super().__init__(
|
||||
# (features): (cfg_batch_size sequence_length embedding_dim)
|
||||
fl.Parallel(
|
||||
fl.Identity(),
|
||||
ExtractReferenceFeatures(),
|
||||
),
|
||||
# (targets, reference)
|
||||
AdaIN(),
|
||||
# (targets_renormalized, reference)
|
||||
fl.Distribute(
|
||||
fl.Identity(),
|
||||
ScaleReferenceFeatures(scale=scale),
|
||||
),
|
||||
# (targets_renormalized, reference_scaled)
|
||||
fl.Concatenate(
|
||||
fl.GetArg(index=0), # targets
|
||||
fl.GetArg(index=1), # reference
|
||||
dim=-2, # sequence_length
|
||||
),
|
||||
# (features_with_shared_reference)
|
||||
)
|
||||
|
||||
if not adain:
|
||||
adain_module = self.ensure_find(AdaIN)
|
||||
self.remove(adain_module)
|
||||
|
||||
if not concatenate:
|
||||
concatenate_module = self.ensure_find(fl.Concatenate)
|
||||
self.replace(
|
||||
old_module=concatenate_module,
|
||||
new_module=fl.GetArg(index=0), # targets
|
||||
)
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
"""The scaling factor for the reference features."""
|
||||
scale_reference = self.ensure_find(ScaleReferenceFeatures)
|
||||
return scale_reference.scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, scale: float) -> None:
|
||||
scale_reference = self.ensure_find(ScaleReferenceFeatures)
|
||||
scale_reference.scale = scale
|
||||
|
||||
|
||||
class SharedSelfAttentionAdapter(fl.Chain, Adapter[fl.SelfAttention]):
|
||||
"""Upgrades a `SelfAttention` layer into a `SharedSelfAttention` layer.
|
||||
|
||||
This adapter inserts 3 `StyleAligned` modules right after
|
||||
the original Q, K, V `Linear`-s (wrapped inside a `fl.Distribute`).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: fl.SelfAttention,
|
||||
scale: float = 1.0,
|
||||
) -> None:
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
self._style_aligned_layers = [
|
||||
StyleAligned( # Query
|
||||
adain=True,
|
||||
concatenate=False,
|
||||
scale=scale,
|
||||
),
|
||||
StyleAligned( # Key
|
||||
adain=True,
|
||||
concatenate=True,
|
||||
scale=scale,
|
||||
),
|
||||
StyleAligned( # Value
|
||||
adain=False,
|
||||
concatenate=True,
|
||||
scale=scale,
|
||||
),
|
||||
]
|
||||
|
||||
@cached_property
|
||||
def style_aligned_layers(self) -> fl.Distribute:
|
||||
return fl.Distribute(*self._style_aligned_layers)
|
||||
|
||||
def inject(self, parent: fl.Chain | None = None) -> "SharedSelfAttentionAdapter":
|
||||
self.target.insert_before_type(
|
||||
module_type=fl.ScaledDotProductAttention,
|
||||
new_module=self.style_aligned_layers,
|
||||
)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
self.target.remove(self.style_aligned_layers)
|
||||
super().eject()
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
return self.style_aligned_layers.layer(0, StyleAligned).scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, scale: float) -> None:
|
||||
for style_aligned_module in self.style_aligned_layers:
|
||||
style_aligned_module.scale = scale
|
||||
|
||||
|
||||
class StyleAlignedAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
"""Upgrade each `SelfAttention` layer of a UNet into a `SharedSelfAttention` layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: T,
|
||||
scale: float = 1.0,
|
||||
) -> None:
|
||||
"""Initialize the StyleAlignedAdapter.
|
||||
|
||||
Args:
|
||||
target: The target module.
|
||||
scale: The scaling factor for the reference features.
|
||||
"""
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
# create a SharedSelfAttentionAdapter for each SelfAttention module
|
||||
self.shared_self_attention_adapters = tuple(
|
||||
SharedSelfAttentionAdapter(
|
||||
target=self_attention,
|
||||
scale=scale,
|
||||
)
|
||||
for self_attention in self.target.layers(fl.SelfAttention)
|
||||
)
|
||||
|
||||
def inject(self, parent: fl.Chain | None = None) -> "StyleAlignedAdapter[T]":
|
||||
for shared_self_attention_adapter in self.shared_self_attention_adapters:
|
||||
shared_self_attention_adapter.inject()
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
for shared_self_attention_adapter in self.shared_self_attention_adapters:
|
||||
shared_self_attention_adapter.eject()
|
||||
super().eject()
|
||||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
"""The scaling factor for the reference features."""
|
||||
return self.shared_self_attention_adapters[0].scale
|
||||
|
||||
@scale.setter
|
||||
def scale(self, scale: float) -> None:
|
||||
for shared_self_attention_adapter in self.shared_self_attention_adapters:
|
||||
shared_self_attention_adapter.scale = scale
|
Loading…
Reference in a new issue