mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
add LcmAdapter
This adds support for the condition scale embedding. Also updates the UNet converter to support LCM.
This commit is contained in:
parent
c8c6294550
commit
f8d55ccb20
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel # type: ignore
|
||||
|
@ -7,6 +8,7 @@ from torch import nn
|
|||
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
|
||||
|
||||
|
||||
class Args(argparse.Namespace):
|
||||
|
@ -28,9 +30,16 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
source_in_channels: int = source.config.in_channels # type: ignore
|
||||
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
||||
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
|
||||
target = (
|
||||
SDXLUNet(in_channels=source_in_channels) if source_has_time_ids else SD1UNet(in_channels=source_in_channels)
|
||||
)
|
||||
source_is_lcm: bool = source.config.time_cond_proj_dim is not None
|
||||
|
||||
if source_has_time_ids:
|
||||
target = SDXLUNet(in_channels=source_in_channels)
|
||||
else:
|
||||
target = SD1UNet(in_channels=source_in_channels)
|
||||
|
||||
if source_is_lcm:
|
||||
assert isinstance(target, SDXLUNet)
|
||||
LcmAdapter(target=target).inject()
|
||||
|
||||
x = torch.randn(1, source_in_channels, 32, 32)
|
||||
timestep = torch.tensor(data=[0])
|
||||
|
@ -45,9 +54,16 @@ def setup_converter(args: Args) -> ModelConverter:
|
|||
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
|
||||
|
||||
target_args = (x,)
|
||||
|
||||
source_kwargs: dict[str, Any] = {}
|
||||
if source_has_time_ids:
|
||||
source_kwargs["added_cond_kwargs"] = added_cond_kwargs
|
||||
if source_is_lcm:
|
||||
source_kwargs["timestep_cond"] = torch.randn(1, source.config.time_cond_proj_dim)
|
||||
|
||||
source_args = {
|
||||
"positional": (x, timestep, clip_text_embeddings),
|
||||
"keyword": {"added_cond_kwargs": added_cond_kwargs} if source_has_time_ids else {},
|
||||
"keyword": source_kwargs,
|
||||
}
|
||||
|
||||
converter = ModelConverter(
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||
|
||||
|
||||
def compute_sinusoidal_embedding(
|
||||
x: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
) -> torch.Tensor:
|
||||
# Differences from compute_sinusoidal_embedding in RangeAdapter:
|
||||
# - we concat [sin, cos], it does the opposite ([cos, sin])
|
||||
# - we divide the exponent by half_dim - 1, it divides by half_dim
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
# Note: it is important that this computation is done in float32.
|
||||
# The result can be cast to lower precision later if necessary.
|
||||
exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
|
||||
exponent /= half_dim - 1
|
||||
embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
|
||||
embedding = torch.cat([torch.sin(embedding), torch.cos(embedding)], dim=-1)
|
||||
|
||||
assert embedding.shape == (x.shape[0], embedding_dim)
|
||||
return embedding
|
||||
|
||||
|
||||
class ResidualBlock(fl.Residual):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
device: torch.device | str | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
fl.UseContext("lcm", "condition_scale_embedding"),
|
||||
fl.Converter(),
|
||||
fl.Linear(in_features=in_channels, out_features=out_channels, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
|
||||
class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
|
||||
def __init__(
|
||||
self,
|
||||
target: SDXLUNet,
|
||||
condition_scale_embedding_dim: int = 256,
|
||||
condition_scale: float = 7.5,
|
||||
) -> None:
|
||||
assert condition_scale_embedding_dim % 2 == 0
|
||||
self.condition_scale_embedding_dim = condition_scale_embedding_dim
|
||||
self.condition_scale = condition_scale
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
def init_context(self) -> Contexts:
|
||||
return {"lcm": {"condition_scale_embedding": self.sinusoidal_embedding}}
|
||||
|
||||
@property
|
||||
def sinusoidal_embedding(self) -> torch.Tensor:
|
||||
return compute_sinusoidal_embedding(
|
||||
torch.tensor([(self.condition_scale - 1) * 1000], device=self.device),
|
||||
embedding_dim=self.condition_scale_embedding_dim,
|
||||
)
|
||||
|
||||
def set_condition_scale(self, scale: float) -> None:
|
||||
self.condition_scale = scale
|
||||
self.set_context("lcm", {"condition_scale_embedding": self.sinusoidal_embedding})
|
||||
|
||||
def inject(self: "LcmAdapter", parent: fl.Chain | None = None) -> "LcmAdapter":
|
||||
ra = self.target.ensure_find(RangeEncoder)
|
||||
block = ResidualBlock(
|
||||
in_channels=self.condition_scale_embedding_dim,
|
||||
out_channels=ra.sinusoidal_embedding_dim,
|
||||
device=self.target.device,
|
||||
dtype=self.target.dtype,
|
||||
)
|
||||
ra.insert_before_type(fl.Linear, block)
|
||||
return super().inject(parent)
|
||||
|
||||
def eject(self) -> None:
|
||||
ra = self.target.ensure_find(RangeEncoder)
|
||||
ra.remove(ra.ensure_find(ResidualBlock))
|
||||
super().eject()
|
Loading…
Reference in a new issue