Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)
add LcmAdapter
This adds support for the condition scale embedding. Also updates the UNet converter to support LCM.
commit f8d55ccb20
parent c8c6294550
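As a rough usage sketch of the API added below (not part of this commit's diff; the latent channel count of 4 and the scale value 8.0 are illustrative assumptions):

from refiners.foundationals.latent_diffusion import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter

# Build the target UNet; 4 is the usual SDXL latent channel count (assumed here).
unet = SDXLUNet(in_channels=4)

# inject() inserts a ResidualBlock into the UNet's RangeEncoder so the condition
# scale embedding gets added to the timestep embedding.
adapter = LcmAdapter(target=unet).inject()

# The condition scale can be changed later; this refreshes the cached embedding.
adapter.set_condition_scale(8.0)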
@@ -1,5 +1,6 @@
 import argparse
 from pathlib import Path
+from typing import Any
 
 import torch
 from diffusers import UNet2DConditionModel  # type: ignore
@@ -7,6 +8,7 @@ from torch import nn
 
 from refiners.fluxion.model_converter import ModelConverter
 from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.lcm import LcmAdapter
 
 
 class Args(argparse.Namespace):
@@ -28,9 +30,16 @@ def setup_converter(args: Args) -> ModelConverter:
     source_in_channels: int = source.config.in_channels  # type: ignore
     source_clip_embedding_dim: int = source.config.cross_attention_dim  # type: ignore
     source_has_time_ids: bool = source.config.addition_embed_type == "text_time"  # type: ignore
-    target = (
-        SDXLUNet(in_channels=source_in_channels) if source_has_time_ids else SD1UNet(in_channels=source_in_channels)
-    )
+    source_is_lcm: bool = source.config.time_cond_proj_dim is not None
+
+    if source_has_time_ids:
+        target = SDXLUNet(in_channels=source_in_channels)
+    else:
+        target = SD1UNet(in_channels=source_in_channels)
+
+    if source_is_lcm:
+        assert isinstance(target, SDXLUNet)
+        LcmAdapter(target=target).inject()
 
     x = torch.randn(1, source_in_channels, 32, 32)
     timestep = torch.tensor(data=[0])
@@ -45,9 +54,16 @@ def setup_converter(args: Args) -> ModelConverter:
         target.set_pooled_text_embedding(pooled_text_embedding=added_cond_kwargs["text_embeds"])
 
     target_args = (x,)
+
+    source_kwargs: dict[str, Any] = {}
+    if source_has_time_ids:
+        source_kwargs["added_cond_kwargs"] = added_cond_kwargs
+    if source_is_lcm:
+        source_kwargs["timestep_cond"] = torch.randn(1, source.config.time_cond_proj_dim)
+
     source_args = {
         "positional": (x, timestep, clip_text_embeddings),
-        "keyword": {"added_cond_kwargs": added_cond_kwargs} if source_has_time_ids else {},
+        "keyword": source_kwargs,
     }
 
     converter = ModelConverter(
@@ -0,0 +1,88 @@
+import math
+
+import torch
+
+import refiners.fluxion.layers as fl
+from refiners.fluxion.adapters.adapter import Adapter
+from refiners.fluxion.context import Contexts
+from refiners.foundationals.latent_diffusion.range_adapter import RangeEncoder
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+
+
+def compute_sinusoidal_embedding(
+    x: torch.Tensor,
+    embedding_dim: int,
+) -> torch.Tensor:
+    # Differences from compute_sinusoidal_embedding in RangeAdapter:
+    # - we concat [sin, cos], it does the opposite ([cos, sin])
+    # - we divide the exponent by half_dim - 1, it divides by half_dim
+
+    half_dim = embedding_dim // 2
+    # Note: it is important that this computation is done in float32.
+    # The result can be cast to lower precision later if necessary.
+    exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=x.device)
+    exponent /= half_dim - 1
+    embedding = x.unsqueeze(1).float() * torch.exp(exponent).unsqueeze(0)
+    embedding = torch.cat([torch.sin(embedding), torch.cos(embedding)], dim=-1)
+
+    assert embedding.shape == (x.shape[0], embedding_dim)
+    return embedding
+
+
+class ResidualBlock(fl.Residual):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        device: torch.device | str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
+        super().__init__(
+            fl.UseContext("lcm", "condition_scale_embedding"),
+            fl.Converter(),
+            fl.Linear(in_features=in_channels, out_features=out_channels, bias=False, device=device, dtype=dtype),
+        )
+
+
+class LcmAdapter(fl.Chain, Adapter[SDXLUNet]):
+    def __init__(
+        self,
+        target: SDXLUNet,
+        condition_scale_embedding_dim: int = 256,
+        condition_scale: float = 7.5,
+    ) -> None:
+        assert condition_scale_embedding_dim % 2 == 0
+        self.condition_scale_embedding_dim = condition_scale_embedding_dim
+        self.condition_scale = condition_scale
+        with self.setup_adapter(target):
+            super().__init__(target)
+
+    def init_context(self) -> Contexts:
+        return {"lcm": {"condition_scale_embedding": self.sinusoidal_embedding}}
+
+    @property
+    def sinusoidal_embedding(self) -> torch.Tensor:
+        return compute_sinusoidal_embedding(
+            torch.tensor([(self.condition_scale - 1) * 1000], device=self.device),
+            embedding_dim=self.condition_scale_embedding_dim,
+        )
+
+    def set_condition_scale(self, scale: float) -> None:
+        self.condition_scale = scale
+        self.set_context("lcm", {"condition_scale_embedding": self.sinusoidal_embedding})
+
+    def inject(self: "LcmAdapter", parent: fl.Chain | None = None) -> "LcmAdapter":
+        ra = self.target.ensure_find(RangeEncoder)
+        block = ResidualBlock(
+            in_channels=self.condition_scale_embedding_dim,
+            out_channels=ra.sinusoidal_embedding_dim,
+            device=self.target.device,
+            dtype=self.target.dtype,
+        )
+        ra.insert_before_type(fl.Linear, block)
+        return super().inject(parent)
+
+    def eject(self) -> None:
+        ra = self.target.ensure_find(RangeEncoder)
+        ra.remove(ra.ensure_find(ResidualBlock))
+        super().eject()