refactor latent_diffusion module

limiteinductive 2023-08-23 00:36:29 +02:00 committed by Benjamin Trom
parent 3ee0ccccdc
commit 92a21bc21e
25 changed files with 331 additions and 263 deletions

View file

@ -249,7 +249,7 @@ lora_weights.patch(sd15, scale=1.0)
prompt = "a cute cat"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(30)
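
Note: the public helper is renamed from compute_text_embedding to compute_clip_text_embedding. A minimal sketch of the updated text-to-image loop, assuming a StableDiffusion_1 instance named sd15 (weight loading and device placement omitted for brevity):

import torch
from refiners.foundationals.latent_diffusion import StableDiffusion_1

sd15 = StableDiffusion_1()  # in real use, load converted weights before sampling
prompt = "a cute cat"
with torch.no_grad():
    # one call now produces the embeddings used for classifier-free guidance
    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
    sd15.set_num_inference_steps(30)
    x = torch.randn(1, 4, 64, 64)
    for step in sd15.steps:
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding)
    image = sd15.lda.decode_latents(x=x)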

View file

@ -6,19 +6,19 @@ from refiners.fluxion.utils import (
convert_state_dict,
save_to_safetensors,
)
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion import UNet
from refiners.foundationals.latent_diffusion import SD1UNet
@torch.no_grad()
def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
controlnet = Controlnet(name="mycn")
controlnet = SD1Controlnet(name="mycn")
condition = torch.randn(1, 3, 512, 512)
controlnet.set_controlnet_condition(condition=condition)
unet = UNet(in_channels=4, clip_embedding_dim=768)
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
unet.insert(index=0, module=controlnet)
clip_text_embedding = torch.rand(1, 77, 768)
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

View file

@ -9,7 +9,7 @@ from torch.nn import Parameter as TorchParameter
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import save_to_safetensors
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
from refiners.adapters.lora import Lora
from refiners.fluxion.utils import create_state_dict_mapping
@ -37,7 +37,7 @@ def process(source: str, base_model: str, output_file: str) -> None:
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model) # type: ignore
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
refiners_model = UNet(in_channels=4, clip_embedding_dim=768)
refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
target = LoraTarget.CrossAttention
metadata = {"unet_targets": "CrossAttentionBlock2d"}
rank = diffusers_state_dict[

View file

@ -5,7 +5,7 @@ from refiners.fluxion.utils import (
)
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget
from refiners.fluxion.layers.module import Module
import refiners.fluxion.layers as fl
@ -19,7 +19,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
@torch.no_grad()
def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dict[str, str] | None:
def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: SD1UNet) -> dict[str, str] | None:
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor(data=[0])
clip_text_embeddings = torch.randn(1, 77, 768)
@ -79,7 +79,7 @@ def main() -> None:
match meta_key:
case "unet_targets":
src_model = diffusers_sd.unet # type: ignore
dst_model = UNet(in_channels=4, clip_embedding_dim=768)
dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
create_mapping = create_unet_mapping
key_prefix = "unet."
lora_prefix = "lora_unet_"

View file

@ -5,12 +5,12 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
from diffusers import StableDiffusionInpaintPipeline # type: ignore
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@torch.no_grad()
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
dst_model = UNet(in_channels=9, clip_embedding_dim=768)
dst_model = SD1UNet(in_channels=9, clip_embedding_dim=768)
x = torch.randn(1, 9, 32, 32)
timestep = torch.tensor(data=[0])

View file

@ -5,12 +5,12 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
from diffusers import DiffusionPipeline # type: ignore
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
@torch.no_grad()
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
dst_model = UNet(in_channels=4, clip_embedding_dim=768)
dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
x = torch.randn(1, 4, 32, 32)
timestep = torch.tensor(data=[0])

View file

@ -6,7 +6,7 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
from diffusers import DiffusionPipeline # type: ignore
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
@torch.no_grad()

View file

@ -1,6 +1,7 @@
from torch import Tensor, arange, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
import refiners.foundationals.latent_diffusion.model as ldm
class TokenEncoder(fl.Embedding):
@ -121,7 +122,7 @@ class TransformerLayer(fl.Chain):
)
class CLIPTextEncoder(fl.Chain):
class CLIPTextEncoder(fl.Chain, ldm.TextEncoderInterface):
structural_attrs = [
"embedding_dim",
"max_sequence_length",
@ -189,10 +190,6 @@ class CLIPTextEncoder(fl.Chain):
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
@property
def unconditional_text_embedding(self) -> Tensor:
return self("")
class CLIPTextEncoderL(CLIPTextEncoder):
"""

View file

@ -1,201 +1,37 @@
from typing import TypeVar
from torch import cat, float32, randn, tensor, device as Device, dtype as DType, Size, Tensor
from PIL import Image
import numpy as np
from refiners.fluxion.utils import image_to_tensor, interpolate
from refiners.fluxion.layers.module import Module
from refiners.foundationals.latent_diffusion.auto_encoder import (
LatentDiffusionAutoencoder,
)
from refiners.foundationals.clip.text_encoder import (
CLIPTextEncoder,
CLIPTextEncoderL,
)
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
SD1UNet,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
SD1Controlnet,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import (
SDXLUNet,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
DoubleTextEncoder,
)
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
__all__ = [
"LatentDiffusionModel",
"UNet",
"StableDiffusion_1",
"StableDiffusion_1_Inpainting",
"SD1UNet",
"SD1Controlnet",
"SDXLUNet",
"DoubleTextEncoder",
"DPMSolver",
"Scheduler",
"CLIPTextEncoder",
"CLIPTextEncoderL",
"LatentDiffusionAutoencoder",
]
class LatentDiffusionModel(Module):
def __init__(
self,
unet: UNet,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: CLIPTextEncoder,
scheduler: Scheduler,
device: Device | str = "cpu",
dtype: DType = float32,
):
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device)
self.dtype = dtype
self.unet = unet.to(self.device, dtype=self.dtype)
self.lda = lda.to(self.device, dtype=self.dtype)
self.clip_text_encoder = clip_text_encoder.to(self.device, dtype=self.dtype)
self.scheduler = scheduler.to(self.device, dtype=self.dtype)
def set_num_inference_steps(self, num_inference_steps: int):
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
final_diffusion_rate = self.scheduler.final_diffusion_rate
device, dtype = self.scheduler.device, self.scheduler.dtype
self.scheduler = self.scheduler.__class__(
num_inference_steps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
).to(device=device, dtype=dtype)
def init_latents(
self,
size: tuple[int, int],
init_image: Image.Image | None = None,
first_step: int = 0,
noise: Tensor | None = None,
) -> Tensor:
if noise is None:
height, width = size
noise = randn(1, 4, height // 8, width // 8, device=self.device)
assert list(noise.shape[2:]) == [
size[0] // 8,
size[1] // 8,
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None:
return noise
encoded_image = self.lda.encode_image(init_image.resize(size))
return self.scheduler.add_noise(encoded_image, noise, self.steps[first_step])
@property
def steps(self) -> list[int]:
return self.scheduler.steps
@property
def timestep_embeddings(self) -> Tensor:
return self.timestep_encoder(self.scheduler.timesteps)
@property
def unconditional_clip_text_embeddings(self) -> Tensor:
return self.clip_text_encoder.unconditional_text_embedding
def compute_text_embedding(self, text: str) -> Tensor:
return self.clip_text_encoder(text)
def forward(
self,
x: Tensor,
step: int,
clip_text_embedding: Tensor,
negative_clip_text_embedding: Tensor | None = None,
condition_scale: float = 7.5,
) -> Tensor:
timestep = self.scheduler.timesteps[step].unsqueeze(0)
self.unet.set_timestep(timestep)
negative_clip_text_embedding = (
self.clip_text_encoder.unconditional_text_embedding
if negative_clip_text_embedding is None
else negative_clip_text_embedding
)
clip_text_embeddings = cat((negative_clip_text_embedding, clip_text_embedding))
self.unet.set_clip_text_embedding(clip_text_embeddings)
latents = cat((x, x)) # for classifier-free guidance
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
return self.scheduler(x, noise=noise, step=step)
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
return self.__class__(
unet=self.unet.structural_copy(),
lda=self.lda.structural_copy(),
clip_text_encoder=self.clip_text_encoder.structural_copy(),
scheduler=self.scheduler,
device=self.device,
dtype=self.dtype,
)
class StableDiffusion_1(LatentDiffusionModel):
def __init__(
self,
unet: UNet | None = None,
lda: LatentDiffusionAutoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = float32,
):
unet = unet or UNet(in_channels=4, clip_embedding_dim=768)
lda = lda or LatentDiffusionAutoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DPMSolver(num_inference_steps=30)
super().__init__(
unet,
lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
class StableDiffusion_1_Inpainting(StableDiffusion_1):
def __init__(
self,
unet: UNet | None = None,
lda: LatentDiffusionAutoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = float32,
):
self.mask_latents: Tensor | None = None
self.target_image_latents: Tensor | None = None
super().__init__(unet, lda, clip_text_encoder, scheduler, device, dtype)
def forward(
self,
x: Tensor,
step: int,
clip_text_embedding: Tensor,
negative_clip_text_embedding: Tensor | None = None,
condition_scale: float = 7.5,
):
assert self.mask_latents is not None
assert self.target_image_latents is not None
x = cat((x, self.mask_latents, self.target_image_latents), dim=1)
return super().forward(x, step, clip_text_embedding, negative_clip_text_embedding, condition_scale)
def set_inpainting_conditions(
self,
target_image: Image.Image,
mask: Image.Image,
latents_size: tuple[int, int] = (64, 64),
) -> tuple[Tensor, Tensor]:
target_image = target_image.convert("RGB")
mask = mask.convert("L")
mask_tensor = tensor(np.array(mask).astype(np.float32) / 255.0).to(self.device)
mask_tensor = (mask_tensor > 0.5).unsqueeze(0).unsqueeze(0).to(dtype=self.dtype)
self.mask_latents = interpolate(mask_tensor, Size(latents_size))
init_image_tensor = image_to_tensor(target_image, device=self.device, dtype=self.dtype) * 2 - 1
masked_init_image = init_image_tensor * (1 - mask_tensor)
self.target_image_latents = self.lda.encode(masked_init_image)
return self.mask_latents, self.target_image_latents
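
The monolithic refiners.foundationals.latent_diffusion module is reduced to re-exports, so existing top-level imports keep working while the implementations move into per-model subpackages. A small sketch of imports that remain valid after the refactor, using only names listed in the new __all__ above:

from refiners.foundationals.latent_diffusion import (
    StableDiffusion_1,
    StableDiffusion_1_Inpainting,
    SD1UNet,
    SD1Controlnet,
    SDXLUNet,
    DoubleTextEncoder,
    DPMSolver,
)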

View file

@ -9,7 +9,7 @@ from refiners.adapters.lora import LoraAdapter, load_lora_weights
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion import StableDiffusion_1
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
@ -74,7 +74,7 @@ class LoraWeights:
match meta_key:
case "unet_targets":
# TODO: support this transparently
if any([isinstance(module, Controlnet) for module in sd.unet]):
if any([isinstance(module, SD1Controlnet) for module in sd.unet]):
raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
model = sd.unet
key_prefix = "unet."

View file

@ -0,0 +1,117 @@
from abc import ABC, abstractmethod
from typing import Protocol, TypeVar
from torch import Tensor, device as Device, dtype as DType
from PIL import Image
import torch
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
T = TypeVar("T", bound="fl.Module")
class UNetInterface(Protocol):
def set_timestep(self, timestep: Tensor) -> None:
...
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
...
def __call__(self, x: Tensor) -> Tensor:
...
class TextEncoderInterface(Protocol):
def __call__(self, text: str) -> Tensor | tuple[Tensor, Tensor]:
...
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
class LatentDiffusionModel(fl.Module, ABC):
def __init__(
self,
unet: UNetInterface,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: TextEncoderInterface,
scheduler: Scheduler,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device=device)
self.dtype = dtype
assert isinstance(unet, fl.Module)
self.unet = unet.to(device=self.device, dtype=self.dtype)
self.lda = lda.to(device=self.device, dtype=self.dtype)
assert isinstance(clip_text_encoder, fl.Module)
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
def set_num_inference_steps(self, num_inference_steps: int) -> None:
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
final_diffusion_rate = self.scheduler.final_diffusion_rate
device, dtype = self.scheduler.device, self.scheduler.dtype
self.scheduler = self.scheduler.__class__(
num_inference_steps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
).to(device=device, dtype=dtype)
def init_latents(
self,
size: tuple[int, int],
init_image: Image.Image | None = None,
first_step: int = 0,
noise: Tensor | None = None,
) -> Tensor:
if noise is None:
height, width = size
noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
assert list(noise.shape[2:]) == [
size[0] // 8,
size[1] // 8,
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
if init_image is None:
return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=size))
return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
@property
def steps(self) -> list[int]:
return self.scheduler.steps
@abstractmethod
def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *args: Tensor) -> None:
...
def forward(
self,
x: Tensor,
step: int,
clip_text_embedding: Tensor,
*args: Tensor,
condition_scale: float = 7.5,
) -> Tensor:
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep, clip_text_embedding, *args)
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
return self.scheduler(x, noise=noise, step=step)
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
return self.__class__(
unet=self.unet.structural_copy(),
lda=self.lda.structural_copy(),
clip_text_encoder=self.clip_text_encoder.structural_copy(),
scheduler=self.scheduler,
device=self.device,
dtype=self.dtype,
)
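
The new model.py turns LatentDiffusionModel into an abstract base class and introduces UNetInterface / TextEncoderInterface protocols; a concrete pipeline only has to route conditioning into its UNet via set_unet_context. A hypothetical sketch of such a subclass, mirroring the StableDiffusion_1 implementation shown further below (the class name is illustrative):

from torch import Tensor
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel

class MyDiffusionModel(LatentDiffusionModel):  # hypothetical subclass, for illustration only
    def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *_: Tensor) -> None:
        # push the current timestep and text conditioning into the wrapped UNet
        # before each denoising step; extra tensors are ignored in this sketch
        self.unet.set_timestep(timestep=timestep)
        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)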

View file

@ -10,7 +10,7 @@ from refiners.fluxion.layers import (
Parallel,
)
from refiners.adapters.adapter import Adapter
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from torch import Tensor
@ -58,7 +58,7 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
class SelfAttentionInjection(Passthrough):
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
def __init__(self, unet: UNet, style_cfg: float = 0.5) -> None:
def __init__(self, unet: SD1UNet, style_cfg: float = 0.5) -> None:
# style_cfg is the weight of the guide in unconditioned diffusion.
# A value of 0.5 is recommended in the sdwebui repository.
self.style_cfg = style_cfg

View file

@ -0,0 +1,13 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
__all__ = [
"StableDiffusion_1",
"StableDiffusion_1_Inpainting",
"SD1UNet",
"SD1Controlnet",
]

View file

@ -1,6 +1,11 @@
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
from refiners.foundationals.latent_diffusion.unet import DownBlocks, MiddleBlock, ResidualBlock, TimestepEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
DownBlocks,
MiddleBlock,
ResidualBlock,
TimestepEncoder,
)
from refiners.adapters.range_adapter import RangeAdapter2d
from typing import cast, Iterable
from torch import Tensor, device as Device, dtype as DType
@ -64,7 +69,7 @@ class ConditionEncoder(Chain):
)
class Controlnet(Passthrough):
class SD1Controlnet(Passthrough):
structural_attrs = ["name", "scale"]
def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
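
Controlnet is renamed SD1Controlnet and moved under stable_diffusion_1. A sketch of wiring it into an SD1 pipeline, based on the test code later in this diff (the safetensors path is illustrative and sd15 weight loading is omitted):

import torch
from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import StableDiffusion_1, SD1Controlnet

sd15 = StableDiffusion_1()
controlnet = SD1Controlnet(name="mycn")
controlnet.load_state_dict(load_from_safetensors("controlnet.safetensors"))  # illustrative path
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
condition = torch.randn(1, 3, 512, 512)  # stand-in for a real condition image tensor
controlnet.set_controlnet_condition(condition=condition)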

View file

@ -0,0 +1,105 @@
import torch
from refiners.fluxion.utils import image_to_tensor, interpolate
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from PIL import Image
import numpy as np
from torch import device as Device, dtype as DType, Tensor
class StableDiffusion_1(LatentDiffusionModel):
unet: SD1UNet
clip_text_encoder: CLIPTextEncoderL
def __init__(
self,
unet: SD1UNet | None = None,
lda: LatentDiffusionAutoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SD1UNet(in_channels=4, clip_embedding_dim=768)
lda = lda or LatentDiffusionAutoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DPMSolver(num_inference_steps=30)
super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
conditional_embedding = self.clip_text_encoder(text)
if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
negative_embedding = self.clip_text_encoder(negative_text or "")
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *_: Tensor) -> None:
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
class StableDiffusion_1_Inpainting(StableDiffusion_1):
def __init__(
self,
unet: SD1UNet | None = None,
lda: LatentDiffusionAutoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
self.mask_latents: Tensor | None = None
self.target_image_latents: Tensor | None = None
super().__init__(
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
)
def forward(
self,
x: Tensor,
step: int,
clip_text_embedding: Tensor,
*_: Tensor,
condition_scale: float = 7.5,
) -> Tensor:
assert self.mask_latents is not None
assert self.target_image_latents is not None
x = torch.cat(tensors=(x, self.mask_latents, self.target_image_latents), dim=1)
return super().forward(
x=x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=condition_scale,
)
def set_inpainting_conditions(
self,
target_image: Image.Image,
mask: Image.Image,
latents_size: tuple[int, int] = (64, 64),
) -> tuple[Tensor, Tensor]:
target_image = target_image.convert(mode="RGB")
mask = mask.convert(mode="L")
mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
masked_init_image = init_image_tensor * (1 - mask_tensor)
self.target_image_latents = self.lda.encode(x=masked_init_image)
return self.mask_latents, self.target_image_latents
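
compute_clip_text_embedding now takes the negative prompt directly and returns the (negative, conditional) pair already concatenated for classifier-free guidance, which is why the updated tests below drop the separate negative_clip_text_embedding argument. A minimal sketch, assuming sd15 is a StableDiffusion_1 instance with weights loaded:

with torch.no_grad():
    clip_text_embedding = sd15.compute_clip_text_embedding(
        text="a cute cat",
        negative_text="lowres, bad anatomy, bad hands, cropped, worst quality",
    )
    sd15.set_num_inference_steps(30)
    x = torch.randn(1, 4, 64, 64)
    for step in sd15.steps:
        # no separate negative_clip_text_embedding argument anymore
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)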

View file

@ -7,6 +7,7 @@ import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
import refiners.foundationals.latent_diffusion.model as ldm
class TimestepEncoder(fl.Passthrough):
@ -242,7 +243,7 @@ class ResidualConcatenator(fl.Chain):
)
class UNet(fl.Chain):
class SD1UNet(fl.Chain, ldm.UNetInterface):
structural_attrs = ["in_channels", "clip_embedding_dim"]
def __init__(

View file

@ -0,0 +1,8 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
__all__ = [
"SDXLUNet",
"DoubleTextEncoder",
]
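
The SDXL components likewise get their own subpackage; the re-exports above make the short import path below equivalent to the fully qualified ones used in the updated tests:

from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet, DoubleTextEncoder
# equivalent to:
# from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
# from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder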

View file

@ -7,6 +7,7 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextE
from jaxtyping import Float
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
from refiners.foundationals.latent_diffusion.model import TextEncoderInterface
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
@ -59,7 +60,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
return x[:, end_of_text_index[0], :]
class DoubleTextEncoder(fl.Chain):
class DoubleTextEncoder(fl.Chain, TextEncoderInterface):
def __init__(
self,
text_encoder_l: CLIPTextEncoderL | None = None,

View file

@ -3,7 +3,12 @@ from torch import Tensor, device as Device, dtype as DType
from refiners.fluxion.context import Contexts
import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.unet import ResidualAccumulator, ResidualBlock, ResidualConcatenator
from refiners.foundationals.latent_diffusion.model import UNetInterface
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
ResidualAccumulator,
ResidualBlock,
ResidualConcatenator,
)
from refiners.adapters.range_adapter import RangeAdapter2d, RangeEncoder, compute_sinusoidal_embedding
@ -242,7 +247,7 @@ class OutputBlock(fl.Chain):
)
class SDXLUNet(fl.Chain):
class SDXLUNet(fl.Chain, UNetInterface):
structural_attrs = ["in_channels"]
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:

View file

@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
from loguru import logger
from torch.utils.data import Dataset
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from torchvision.transforms import RandomCrop # type: ignore
@ -103,9 +103,9 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
@cached_property
def unet(self) -> UNet:
def unet(self) -> SD1UNet:
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
return UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
return SD1UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
@cached_property
def text_encoder(self) -> CLIPTextEncoderL:
@ -171,14 +171,12 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
for i in range(num_images_per_prompt):
logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
x = randn(1, 4, 64, 64, device=self.device)
clip_text_embedding = sd.compute_text_embedding(text=prompt).to(device=self.device)
negative_clip_text_embedding = sd.compute_text_embedding(text="").to(device=self.device)
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt).to(device=self.device)
for step in sd.steps:
x = sd(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
)
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
images[prompt] = canvas_image

View file

@ -9,8 +9,8 @@ from pathlib import Path
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, manual_seed
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
from refiners.foundationals.latent_diffusion.lora import LoraWeights
from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
@ -196,7 +196,7 @@ def sd15_inpainting(
warn("not running on CPU, skipping")
pytest.skip()
unet = UNet(in_channels=9, clip_embedding_dim=768)
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
@ -214,7 +214,7 @@ def sd15_inpainting_float16(
warn("not running on CPU, skipping")
pytest.skip()
unet = UNet(in_channels=9, clip_embedding_dim=768)
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
@ -253,8 +253,7 @@ def test_diffusion_std_random_init(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
@ -267,7 +266,6 @@ def test_diffusion_std_random_init(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
@ -286,11 +284,9 @@ def test_diffusion_std_random_init_float16(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
assert clip_text_embedding.dtype == torch.float16
assert negative_clip_text_embedding.dtype == torch.float16
sd15.set_num_inference_steps(n_steps)
@ -303,7 +299,6 @@ def test_diffusion_std_random_init_float16(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
@ -325,8 +320,7 @@ def test_diffusion_std_init_image(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
@ -339,7 +333,6 @@ def test_diffusion_std_init_image(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
@ -362,8 +355,7 @@ def test_diffusion_inpainting(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
@ -377,7 +369,6 @@ def test_diffusion_inpainting(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
@ -401,11 +392,9 @@ def test_diffusion_inpainting_float16(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
assert clip_text_embedding.dtype == torch.float16
assert negative_clip_text_embedding.dtype == torch.float16
sd15.set_num_inference_steps(n_steps)
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
@ -419,7 +408,6 @@ def test_diffusion_inpainting_float16(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
@ -447,13 +435,12 @@ def test_diffusion_controlnet(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = Controlnet(name=cn_name, device=test_device)
controlnet = SD1Controlnet(name=cn_name, device=test_device)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
@ -470,7 +457,6 @@ def test_diffusion_controlnet(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
@ -498,13 +484,12 @@ def test_diffusion_controlnet_structural_copy(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = Controlnet(name=cn_name, device=test_device)
controlnet = SD1Controlnet(name=cn_name, device=test_device)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
@ -521,7 +506,6 @@ def test_diffusion_controlnet_structural_copy(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
@ -548,13 +532,12 @@ def test_diffusion_controlnet_float16(
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
controlnet_state_dict = load_from_safetensors(cn_weights_path)
controlnet = Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
controlnet = SD1Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
controlnet.load_state_dict(controlnet_state_dict)
controlnet.set_scale(0.5)
sd15.unet.insert(0, controlnet)
@ -571,7 +554,6 @@ def test_diffusion_controlnet_float16(
x,
step=step,
clip_text_embedding=clip_text_embedding,
negative_clip_text_embedding=negative_clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
@ -597,7 +579,7 @@ def test_diffusion_lora(
prompt = "a cute cat"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sd15.set_num_inference_steps(n_steps)
@ -631,7 +613,7 @@ def test_diffusion_refonly(
prompt = "Chicken"
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sai = SelfAttentionInjection(sd15.unet)
sai.inject()
@ -673,7 +655,7 @@ def test_diffusion_inpainting_refonly(
prompt = "" # unconditional
with torch.no_grad():
clip_text_embedding = sd15.compute_text_embedding(prompt)
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
sai = SelfAttentionInjection(sd15.unet)
sai.inject()

View file

@ -8,7 +8,7 @@ from torch import Tensor
from refiners.fluxion.utils import manual_seed
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
from refiners.foundationals.latent_diffusion.sdxl_text_encoder import DoubleTextEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
class DiffusersSDXL(Protocol):

View file

@ -4,7 +4,7 @@ from warnings import warn
import pytest
import torch
from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.fluxion.utils import compare_models

View file

@ -1,4 +1,4 @@
from refiners.foundationals.latent_diffusion.unet import UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.fluxion import manual_seed
import torch
@ -9,7 +9,7 @@ def test_unet_context_flush():
timestep = torch.randint(0, 999, size=(1, 1))
x = torch.randn(1, 4, 32, 32)
unet = UNet(in_channels=4, clip_embedding_dim=768)
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
with torch.no_grad():