add LcmAdapter

This adds support for the condition scale embedding.
Also updates the UNet converter to support LCM.
This commit is contained in:
Pierre Chapuis 2024-02-15 18:58:23 +01:00
parent c8c6294550
commit f8d55ccb20
2 changed files with 108 additions and 4 deletions

View file

@ -1,5 +1,6 @@
import argparse
from pathlib import Path
from typing import Any
import torch
from diffusers import UNet2DConditionModel # type: ignore
@ -7,6 +8,7 @@ from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
class Args(argparse.Namespace):
@ -28,9 +30,16 @@ def setup_converter(args: Args) -> ModelConverter:
source_in_channels: int = source.config.in_channels # type: ignore
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
source_has_time_ids: bool = source.config.addition_embed_type == "text_time" # type: ignore
target = (
SDXLUNet(in_channels=source_in_channels) if source_has_time_ids else SD1UNet(in_channels=source_in_channels)
)
source_is_lcm: bool = source.config.time_cond_proj_dim is not None
if source_has_time_ids:
target = SDXLUNet(in_channels=source_in_channels)
else:
target = SD1UNet(in_channels=source_in_channels)
if source_is_lcm:
assert isinstance(target, SDXLUNet)
LcmAdapter(target=target).inject()
x = torch.randn(1, source_in_channels, 32, 32)
timestep = torch.tensor(data=[0])
@ -45,9 +54,16 @@ def setup_converter(args: Args) -> ModelConverter:
target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
target_args = (x,)
source_kwargs: dict[str, Any] = {}
if source_has_time_ids:
source_kwargs["added_cond_kwargs"] = added_cond_kwargs
if source_is_lcm:
source_kwargs["timestep_cond"] = torch.randn(1, source.config.time_cond_proj_dim)
source_args = {
"positional": (x, timestep, clip_text_embeddings),
"keyword": {"added_cond_kwargs": added_cond_kwargs} if source_has_time_ids else {},
"keyword": source_kwargs,
}
converter = ModelConverter(

View file

@ -0,0 +1,88 @@
import math
import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
def compute_sinusoidal_embedding(
x: torch.Tensor,
embedding_dim: int,
) -> torch.Tensor:
# Differences from compute_sinusoidal_embedding in RangeAdapter:
# - we concat [sin, cos], it does the opposite ([cos, sin])
# - we divide the exponent by half_dim - 1, it divides by half_dim
half_dim = embedding_dim // 2
# Note: it is important that this computation is done in float32.
# The result can be cast to lower precision later if necessary.
exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
exponent /= half_dim - 1
embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
embedding = torch.cat([torch.sin(embedding), torch.cos(embedding)], dim=-1)
assert embedding.shape == (x.shape[0], embedding_dim)
return embedding
class ResidualBlock(fl.Residual):
def __init__(
self,
in_channels: int,
out_channels: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
super().__init__(
fl.UseContext("lcm", "condition_scale_embedding"),
fl.Converter(),
fl.Linear(in_features=in_channels, out_features=out_channels, bias=False, device=device, dtype=dtype),
)
class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
def __init__(
self,
target: SDXLUNet,
condition_scale_embedding_dim: int = 256,
condition_scale: float = 7.5,
) -> None:
assert condition_scale_embedding_dim % 2 == 0
self.condition_scale_embedding_dim = condition_scale_embedding_dim
self.condition_scale = condition_scale
with self.setup_adapter(target):
super().__init__(target)
def init_context(self) -> Contexts:
return {"lcm": {"condition_scale_embedding": self.sinusoidal_embedding}}
@property
def sinusoidal_embedding(self) -> torch.Tensor:
return compute_sinusoidal_embedding(
torch.tensor([(self.condition_scale - 1) * 1000], device=self.device),
embedding_dim=self.condition_scale_embedding_dim,
)
def set_condition_scale(self, scale: float) -> None:
self.condition_scale = scale
self.set_context("lcm", {"condition_scale_embedding": self.sinusoidal_embedding})
def inject(self: "LcmAdapter", parent: fl.Chain | None = None) -> "LcmAdapter":
ra = self.target.ensure_find(RangeEncoder)
block = ResidualBlock(
in_channels=self.condition_scale_embedding_dim,
out_channels=ra.sinusoidal_embedding_dim,
device=self.target.device,
dtype=self.target.dtype,
)
ra.insert_before_type(fl.Linear, block)
return super().inject(parent)
def eject(self) -> None:
ra = self.target.ensure_find(RangeEncoder)
ra.remove(ra.ensure_find(ResidualBlock))
super().eject()