mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
refactor latent_diffusion module
This commit is contained in:
parent
3ee0ccccdc
commit
92a21bc21e
|
@ -249,7 +249,7 @@ lora_weights.patch(sd15, scale=1.0)
|
||||||
prompt = "a cute cat"
|
prompt = "a cute cat"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||||
|
|
||||||
sd15.set_num_inference_steps(30)
|
sd15.set_num_inference_steps(30)
|
||||||
|
|
||||||
|
|
|
@ -6,19 +6,19 @@ from refiners.fluxion.utils import (
|
||||||
convert_state_dict,
|
convert_state_dict,
|
||||||
save_to_safetensors,
|
save_to_safetensors,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
|
||||||
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion import UNet
|
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
|
||||||
controlnet = Controlnet(name="mycn")
|
controlnet = SD1Controlnet(name="mycn")
|
||||||
|
|
||||||
condition = torch.randn(1, 3, 512, 512)
|
condition = torch.randn(1, 3, 512, 512)
|
||||||
controlnet.set_controlnet_condition(condition=condition)
|
controlnet.set_controlnet_condition(condition=condition)
|
||||||
|
|
||||||
unet = UNet(in_channels=4, clip_embedding_dim=768)
|
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
unet.insert(index=0, module=controlnet)
|
unet.insert(index=0, module=controlnet)
|
||||||
clip_text_embedding = torch.rand(1, 77, 768)
|
clip_text_embedding = torch.rand(1, 77, 768)
|
||||||
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from torch.nn import Parameter as TorchParameter
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
from refiners.fluxion.utils import save_to_safetensors
|
from refiners.fluxion.utils import save_to_safetensors
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
|
||||||
from refiners.adapters.lora import Lora
|
from refiners.adapters.lora import Lora
|
||||||
from refiners.fluxion.utils import create_state_dict_mapping
|
from refiners.fluxion.utils import create_state_dict_mapping
|
||||||
|
@ -37,7 +37,7 @@ def process(source: str, base_model: str, output_file: str) -> None:
|
||||||
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model) # type: ignore
|
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model) # type: ignore
|
||||||
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
||||||
|
|
||||||
refiners_model = UNet(in_channels=4, clip_embedding_dim=768)
|
refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
target = LoraTarget.CrossAttention
|
target = LoraTarget.CrossAttention
|
||||||
metadata = {"unet_targets": "CrossAttentionBlock2d"}
|
metadata = {"unet_targets": "CrossAttentionBlock2d"}
|
||||||
rank = diffusers_state_dict[
|
rank = diffusers_state_dict[
|
||||||
|
|
|
@ -5,7 +5,7 @@ from refiners.fluxion.utils import (
|
||||||
)
|
)
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
from refiners.foundationals.latent_diffusion.lora import LoraTarget
|
||||||
from refiners.fluxion.layers.module import Module
|
from refiners.fluxion.layers.module import Module
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
|
@ -19,7 +19,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dict[str, str] | None:
|
def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: SD1UNet) -> dict[str, str] | None:
|
||||||
x = torch.randn(1, 4, 32, 32)
|
x = torch.randn(1, 4, 32, 32)
|
||||||
timestep = torch.tensor(data=[0])
|
timestep = torch.tensor(data=[0])
|
||||||
clip_text_embeddings = torch.randn(1, 77, 768)
|
clip_text_embeddings = torch.randn(1, 77, 768)
|
||||||
|
@ -79,7 +79,7 @@ def main() -> None:
|
||||||
match meta_key:
|
match meta_key:
|
||||||
case "unet_targets":
|
case "unet_targets":
|
||||||
src_model = diffusers_sd.unet # type: ignore
|
src_model = diffusers_sd.unet # type: ignore
|
||||||
dst_model = UNet(in_channels=4, clip_embedding_dim=768)
|
dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
create_mapping = create_unet_mapping
|
create_mapping = create_unet_mapping
|
||||||
key_prefix = "unet."
|
key_prefix = "unet."
|
||||||
lora_prefix = "lora_unet_"
|
lora_prefix = "lora_unet_"
|
||||||
|
|
|
@ -5,12 +5,12 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
|
||||||
from diffusers import StableDiffusionInpaintPipeline # type: ignore
|
from diffusers import StableDiffusionInpaintPipeline # type: ignore
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
||||||
dst_model = UNet(in_channels=9, clip_embedding_dim=768)
|
dst_model = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||||
|
|
||||||
x = torch.randn(1, 9, 32, 32)
|
x = torch.randn(1, 9, 32, 32)
|
||||||
timestep = torch.tensor(data=[0])
|
timestep = torch.tensor(data=[0])
|
||||||
|
|
|
@ -5,12 +5,12 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
from diffusers import DiffusionPipeline # type: ignore
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
|
||||||
dst_model = UNet(in_channels=4, clip_embedding_dim=768)
|
dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
|
|
||||||
x = torch.randn(1, 4, 32, 32)
|
x = torch.randn(1, 4, 32, 32)
|
||||||
timestep = torch.tensor(data=[0])
|
timestep = torch.tensor(data=[0])
|
||||||
|
|
|
@ -6,7 +6,7 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
|
||||||
from diffusers import DiffusionPipeline # type: ignore
|
from diffusers import DiffusionPipeline # type: ignore
|
||||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from torch import Tensor, arange, device as Device, dtype as DType
|
from torch import Tensor, arange, device as Device, dtype as DType
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
import refiners.foundationals.latent_diffusion.model as ldm
|
||||||
|
|
||||||
|
|
||||||
class TokenEncoder(fl.Embedding):
|
class TokenEncoder(fl.Embedding):
|
||||||
|
@ -121,7 +122,7 @@ class TransformerLayer(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncoder(fl.Chain):
|
class CLIPTextEncoder(fl.Chain, ldm.TextEncoderInterface):
|
||||||
structural_attrs = [
|
structural_attrs = [
|
||||||
"embedding_dim",
|
"embedding_dim",
|
||||||
"max_sequence_length",
|
"max_sequence_length",
|
||||||
|
@ -189,10 +190,6 @@ class CLIPTextEncoder(fl.Chain):
|
||||||
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
|
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
|
||||||
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())
|
||||||
|
|
||||||
@property
|
|
||||||
def unconditional_text_embedding(self) -> Tensor:
|
|
||||||
return self("")
|
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncoderL(CLIPTextEncoder):
|
class CLIPTextEncoderL(CLIPTextEncoder):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,201 +1,37 @@
|
||||||
from typing import TypeVar
|
|
||||||
from torch import cat, float32, randn, tensor, device as Device, dtype as DType, Size, Tensor
|
|
||||||
from PIL import Image
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from refiners.fluxion.utils import image_to_tensor, interpolate
|
|
||||||
from refiners.fluxion.layers.module import Module
|
|
||||||
from refiners.foundationals.latent_diffusion.auto_encoder import (
|
from refiners.foundationals.latent_diffusion.auto_encoder import (
|
||||||
LatentDiffusionAutoencoder,
|
LatentDiffusionAutoencoder,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.clip.text_encoder import (
|
from refiners.foundationals.clip.text_encoder import (
|
||||||
CLIPTextEncoder,
|
|
||||||
CLIPTextEncoderL,
|
CLIPTextEncoderL,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver
|
from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||||
|
StableDiffusion_1,
|
||||||
|
StableDiffusion_1_Inpainting,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||||
|
SD1UNet,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
|
||||||
|
SD1Controlnet,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import (
|
||||||
|
SDXLUNet,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
|
||||||
|
DoubleTextEncoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LatentDiffusionModel",
|
"StableDiffusion_1",
|
||||||
"UNet",
|
"StableDiffusion_1_Inpainting",
|
||||||
|
"SD1UNet",
|
||||||
|
"SD1Controlnet",
|
||||||
|
"SDXLUNet",
|
||||||
|
"DoubleTextEncoder",
|
||||||
"DPMSolver",
|
"DPMSolver",
|
||||||
"Scheduler",
|
"Scheduler",
|
||||||
"CLIPTextEncoder",
|
"CLIPTextEncoderL",
|
||||||
"LatentDiffusionAutoencoder",
|
"LatentDiffusionAutoencoder",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class LatentDiffusionModel(Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
unet: UNet,
|
|
||||||
lda: LatentDiffusionAutoencoder,
|
|
||||||
clip_text_encoder: CLIPTextEncoder,
|
|
||||||
scheduler: Scheduler,
|
|
||||||
device: Device | str = "cpu",
|
|
||||||
dtype: DType = float32,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.device: Device = device if isinstance(device, Device) else Device(device)
|
|
||||||
self.dtype = dtype
|
|
||||||
self.unet = unet.to(self.device, dtype=self.dtype)
|
|
||||||
self.lda = lda.to(self.device, dtype=self.dtype)
|
|
||||||
self.clip_text_encoder = clip_text_encoder.to(self.device, dtype=self.dtype)
|
|
||||||
self.scheduler = scheduler.to(self.device, dtype=self.dtype)
|
|
||||||
|
|
||||||
def set_num_inference_steps(self, num_inference_steps: int):
|
|
||||||
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
|
|
||||||
final_diffusion_rate = self.scheduler.final_diffusion_rate
|
|
||||||
device, dtype = self.scheduler.device, self.scheduler.dtype
|
|
||||||
self.scheduler = self.scheduler.__class__(
|
|
||||||
num_inference_steps,
|
|
||||||
initial_diffusion_rate=initial_diffusion_rate,
|
|
||||||
final_diffusion_rate=final_diffusion_rate,
|
|
||||||
).to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
def init_latents(
|
|
||||||
self,
|
|
||||||
size: tuple[int, int],
|
|
||||||
init_image: Image.Image | None = None,
|
|
||||||
first_step: int = 0,
|
|
||||||
noise: Tensor | None = None,
|
|
||||||
) -> Tensor:
|
|
||||||
if noise is None:
|
|
||||||
height, width = size
|
|
||||||
noise = randn(1, 4, height // 8, width // 8, device=self.device)
|
|
||||||
assert list(noise.shape[2:]) == [
|
|
||||||
size[0] // 8,
|
|
||||||
size[1] // 8,
|
|
||||||
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
|
|
||||||
if init_image is None:
|
|
||||||
return noise
|
|
||||||
encoded_image = self.lda.encode_image(init_image.resize(size))
|
|
||||||
return self.scheduler.add_noise(encoded_image, noise, self.steps[first_step])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def steps(self) -> list[int]:
|
|
||||||
return self.scheduler.steps
|
|
||||||
|
|
||||||
@property
|
|
||||||
def timestep_embeddings(self) -> Tensor:
|
|
||||||
return self.timestep_encoder(self.scheduler.timesteps)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def unconditional_clip_text_embeddings(self) -> Tensor:
|
|
||||||
return self.clip_text_encoder.unconditional_text_embedding
|
|
||||||
|
|
||||||
def compute_text_embedding(self, text: str) -> Tensor:
|
|
||||||
return self.clip_text_encoder(text)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
step: int,
|
|
||||||
clip_text_embedding: Tensor,
|
|
||||||
negative_clip_text_embedding: Tensor | None = None,
|
|
||||||
condition_scale: float = 7.5,
|
|
||||||
) -> Tensor:
|
|
||||||
timestep = self.scheduler.timesteps[step].unsqueeze(0)
|
|
||||||
self.unet.set_timestep(timestep)
|
|
||||||
|
|
||||||
negative_clip_text_embedding = (
|
|
||||||
self.clip_text_encoder.unconditional_text_embedding
|
|
||||||
if negative_clip_text_embedding is None
|
|
||||||
else negative_clip_text_embedding
|
|
||||||
)
|
|
||||||
|
|
||||||
clip_text_embeddings = cat((negative_clip_text_embedding, clip_text_embedding))
|
|
||||||
|
|
||||||
self.unet.set_clip_text_embedding(clip_text_embeddings)
|
|
||||||
latents = cat((x, x)) # for classifier-free guidance
|
|
||||||
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
|
||||||
|
|
||||||
# classifier-free guidance
|
|
||||||
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
|
|
||||||
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
|
||||||
return self.scheduler(x, noise=noise, step=step)
|
|
||||||
|
|
||||||
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
|
|
||||||
return self.__class__(
|
|
||||||
unet=self.unet.structural_copy(),
|
|
||||||
lda=self.lda.structural_copy(),
|
|
||||||
clip_text_encoder=self.clip_text_encoder.structural_copy(),
|
|
||||||
scheduler=self.scheduler,
|
|
||||||
device=self.device,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion_1(LatentDiffusionModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
unet: UNet | None = None,
|
|
||||||
lda: LatentDiffusionAutoencoder | None = None,
|
|
||||||
clip_text_encoder: CLIPTextEncoderL | None = None,
|
|
||||||
scheduler: Scheduler | None = None,
|
|
||||||
device: Device | str = "cpu",
|
|
||||||
dtype: DType = float32,
|
|
||||||
):
|
|
||||||
unet = unet or UNet(in_channels=4, clip_embedding_dim=768)
|
|
||||||
lda = lda or LatentDiffusionAutoencoder()
|
|
||||||
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
|
||||||
scheduler = scheduler or DPMSolver(num_inference_steps=30)
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
unet,
|
|
||||||
lda,
|
|
||||||
clip_text_encoder=clip_text_encoder,
|
|
||||||
scheduler=scheduler,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
unet: UNet | None = None,
|
|
||||||
lda: LatentDiffusionAutoencoder | None = None,
|
|
||||||
clip_text_encoder: CLIPTextEncoderL | None = None,
|
|
||||||
scheduler: Scheduler | None = None,
|
|
||||||
device: Device | str = "cpu",
|
|
||||||
dtype: DType = float32,
|
|
||||||
):
|
|
||||||
self.mask_latents: Tensor | None = None
|
|
||||||
self.target_image_latents: Tensor | None = None
|
|
||||||
super().__init__(unet, lda, clip_text_encoder, scheduler, device, dtype)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
step: int,
|
|
||||||
clip_text_embedding: Tensor,
|
|
||||||
negative_clip_text_embedding: Tensor | None = None,
|
|
||||||
condition_scale: float = 7.5,
|
|
||||||
):
|
|
||||||
assert self.mask_latents is not None
|
|
||||||
assert self.target_image_latents is not None
|
|
||||||
x = cat((x, self.mask_latents, self.target_image_latents), dim=1)
|
|
||||||
return super().forward(x, step, clip_text_embedding, negative_clip_text_embedding, condition_scale)
|
|
||||||
|
|
||||||
def set_inpainting_conditions(
|
|
||||||
self,
|
|
||||||
target_image: Image.Image,
|
|
||||||
mask: Image.Image,
|
|
||||||
latents_size: tuple[int, int] = (64, 64),
|
|
||||||
) -> tuple[Tensor, Tensor]:
|
|
||||||
target_image = target_image.convert("RGB")
|
|
||||||
mask = mask.convert("L")
|
|
||||||
|
|
||||||
mask_tensor = tensor(np.array(mask).astype(np.float32) / 255.0).to(self.device)
|
|
||||||
mask_tensor = (mask_tensor > 0.5).unsqueeze(0).unsqueeze(0).to(dtype=self.dtype)
|
|
||||||
self.mask_latents = interpolate(mask_tensor, Size(latents_size))
|
|
||||||
|
|
||||||
init_image_tensor = image_to_tensor(target_image, device=self.device, dtype=self.dtype) * 2 - 1
|
|
||||||
masked_init_image = init_image_tensor * (1 - mask_tensor)
|
|
||||||
self.target_image_latents = self.lda.encode(masked_init_image)
|
|
||||||
|
|
||||||
return self.mask_latents, self.target_image_latents
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ from refiners.adapters.lora import LoraAdapter, load_lora_weights
|
||||||
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1
|
||||||
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors
|
||||||
|
|
||||||
|
@ -74,7 +74,7 @@ class LoraWeights:
|
||||||
match meta_key:
|
match meta_key:
|
||||||
case "unet_targets":
|
case "unet_targets":
|
||||||
# TODO: support this transparently
|
# TODO: support this transparently
|
||||||
if any([isinstance(module, Controlnet) for module in sd.unet]):
|
if any([isinstance(module, SD1Controlnet) for module in sd.unet]):
|
||||||
raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
|
raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
|
||||||
model = sd.unet
|
model = sd.unet
|
||||||
key_prefix = "unet."
|
key_prefix = "unet."
|
||||||
|
|
117
src/refiners/foundationals/latent_diffusion/model.py
Normal file
117
src/refiners/foundationals/latent_diffusion/model.py
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Protocol, TypeVar
|
||||||
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
import refiners.fluxion.layers as fl
|
||||||
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="fl.Module")
|
||||||
|
|
||||||
|
|
||||||
|
class UNetInterface(Protocol):
|
||||||
|
def set_timestep(self, timestep: Tensor) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor) -> Tensor:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncoderInterface(Protocol):
|
||||||
|
def __call__(self, text: str) -> Tensor | tuple[Tensor, Tensor]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
|
||||||
|
|
||||||
|
|
||||||
|
class LatentDiffusionModel(fl.Module, ABC):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
unet: UNetInterface,
|
||||||
|
lda: LatentDiffusionAutoencoder,
|
||||||
|
clip_text_encoder: TextEncoderInterface,
|
||||||
|
scheduler: Scheduler,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType = torch.float32,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.device: Device = device if isinstance(device, Device) else Device(device=device)
|
||||||
|
self.dtype = dtype
|
||||||
|
assert isinstance(unet, fl.Module)
|
||||||
|
self.unet = unet.to(device=self.device, dtype=self.dtype)
|
||||||
|
self.lda = lda.to(device=self.device, dtype=self.dtype)
|
||||||
|
assert isinstance(clip_text_encoder, fl.Module)
|
||||||
|
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
|
||||||
|
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
|
||||||
|
|
||||||
|
def set_num_inference_steps(self, num_inference_steps: int) -> None:
|
||||||
|
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
|
||||||
|
final_diffusion_rate = self.scheduler.final_diffusion_rate
|
||||||
|
device, dtype = self.scheduler.device, self.scheduler.dtype
|
||||||
|
self.scheduler = self.scheduler.__class__(
|
||||||
|
num_inference_steps,
|
||||||
|
initial_diffusion_rate=initial_diffusion_rate,
|
||||||
|
final_diffusion_rate=final_diffusion_rate,
|
||||||
|
).to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def init_latents(
|
||||||
|
self,
|
||||||
|
size: tuple[int, int],
|
||||||
|
init_image: Image.Image | None = None,
|
||||||
|
first_step: int = 0,
|
||||||
|
noise: Tensor | None = None,
|
||||||
|
) -> Tensor:
|
||||||
|
if noise is None:
|
||||||
|
height, width = size
|
||||||
|
noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
|
||||||
|
assert list(noise.shape[2:]) == [
|
||||||
|
size[0] // 8,
|
||||||
|
size[1] // 8,
|
||||||
|
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
|
||||||
|
if init_image is None:
|
||||||
|
return noise
|
||||||
|
encoded_image = self.lda.encode_image(image=init_image.resize(size=size))
|
||||||
|
return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def steps(self) -> list[int]:
|
||||||
|
return self.scheduler.steps
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *args: Tensor) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
step: int,
|
||||||
|
clip_text_embedding: Tensor,
|
||||||
|
*args: Tensor,
|
||||||
|
condition_scale: float = 7.5,
|
||||||
|
) -> Tensor:
|
||||||
|
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
|
||||||
|
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, *args)
|
||||||
|
|
||||||
|
latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
|
||||||
|
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||||
|
|
||||||
|
# classifier-free guidance
|
||||||
|
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
|
||||||
|
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
||||||
|
return self.scheduler(x, noise=noise, step=step)
|
||||||
|
|
||||||
|
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
|
||||||
|
return self.__class__(
|
||||||
|
unet=self.unet.structural_copy(),
|
||||||
|
lda=self.lda.structural_copy(),
|
||||||
|
clip_text_encoder=self.clip_text_encoder.structural_copy(),
|
||||||
|
scheduler=self.scheduler,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
|
@ -10,7 +10,7 @@ from refiners.fluxion.layers import (
|
||||||
Parallel,
|
Parallel,
|
||||||
)
|
)
|
||||||
from refiners.adapters.adapter import Adapter
|
from refiners.adapters.adapter import Adapter
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
|
||||||
class SelfAttentionInjection(Passthrough):
|
class SelfAttentionInjection(Passthrough):
|
||||||
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
|
# TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance
|
||||||
|
|
||||||
def __init__(self, unet: UNet, style_cfg: float = 0.5) -> None:
|
def __init__(self, unet: SD1UNet, style_cfg: float = 0.5) -> None:
|
||||||
# the style_cfg is the weight of the guide in unconditionned diffusion.
|
# the style_cfg is the weight of the guide in unconditionned diffusion.
|
||||||
# This value is recommended to be 0.5 on the sdwebui repo.
|
# This value is recommended to be 0.5 on the sdwebui repo.
|
||||||
self.style_cfg = style_cfg
|
self.style_cfg = style_cfg
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||||
|
StableDiffusion_1,
|
||||||
|
StableDiffusion_1_Inpainting,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"StableDiffusion_1",
|
||||||
|
"StableDiffusion_1_Inpainting",
|
||||||
|
"SD1UNet",
|
||||||
|
"SD1Controlnet",
|
||||||
|
]
|
|
@ -1,6 +1,11 @@
|
||||||
from refiners.fluxion.context import Contexts
|
from refiners.fluxion.context import Contexts
|
||||||
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
|
from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
|
||||||
from refiners.foundationals.latent_diffusion.unet import DownBlocks, MiddleBlock, ResidualBlock, TimestepEncoder
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||||
|
DownBlocks,
|
||||||
|
MiddleBlock,
|
||||||
|
ResidualBlock,
|
||||||
|
TimestepEncoder,
|
||||||
|
)
|
||||||
from refiners.adapters.range_adapter import RangeAdapter2d
|
from refiners.adapters.range_adapter import RangeAdapter2d
|
||||||
from typing import cast, Iterable
|
from typing import cast, Iterable
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from torch import Tensor, device as Device, dtype as DType
|
||||||
|
@ -64,7 +69,7 @@ class ConditionEncoder(Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Controlnet(Passthrough):
|
class SD1Controlnet(Passthrough):
|
||||||
structural_attrs = ["name", "scale"]
|
structural_attrs = ["name", "scale"]
|
||||||
|
|
||||||
def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
|
@ -0,0 +1,105 @@
|
||||||
|
import torch
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, interpolate
|
||||||
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
|
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from torch import device as Device, dtype as DType, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusion_1(LatentDiffusionModel):
|
||||||
|
unet: SD1UNet
|
||||||
|
clip_text_encoder: CLIPTextEncoderL
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
unet: SD1UNet | None = None,
|
||||||
|
lda: LatentDiffusionAutoencoder | None = None,
|
||||||
|
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||||
|
scheduler: Scheduler | None = None,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType = torch.float32,
|
||||||
|
) -> None:
|
||||||
|
unet = unet or SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
|
lda = lda or LatentDiffusionAutoencoder()
|
||||||
|
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
||||||
|
scheduler = scheduler or DPMSolver(num_inference_steps=30)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
unet=unet,
|
||||||
|
lda=lda,
|
||||||
|
clip_text_encoder=clip_text_encoder,
|
||||||
|
scheduler=scheduler,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
|
||||||
|
conditional_embedding = self.clip_text_encoder(text)
|
||||||
|
if text == negative_text:
|
||||||
|
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
|
||||||
|
|
||||||
|
negative_embedding = self.clip_text_encoder(negative_text or "")
|
||||||
|
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
|
||||||
|
|
||||||
|
def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *_: Tensor) -> None:
|
||||||
|
self.unet.set_timestep(timestep=timestep)
|
||||||
|
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||||
|
|
||||||
|
|
||||||
|
class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
unet: SD1UNet | None = None,
|
||||||
|
lda: LatentDiffusionAutoencoder | None = None,
|
||||||
|
clip_text_encoder: CLIPTextEncoderL | None = None,
|
||||||
|
scheduler: Scheduler | None = None,
|
||||||
|
device: Device | str = "cpu",
|
||||||
|
dtype: DType = torch.float32,
|
||||||
|
) -> None:
|
||||||
|
self.mask_latents: Tensor | None = None
|
||||||
|
self.target_image_latents: Tensor | None = None
|
||||||
|
super().__init__(
|
||||||
|
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
step: int,
|
||||||
|
clip_text_embedding: Tensor,
|
||||||
|
*_: Tensor,
|
||||||
|
condition_scale: float = 7.5,
|
||||||
|
) -> Tensor:
|
||||||
|
assert self.mask_latents is not None
|
||||||
|
assert self.target_image_latents is not None
|
||||||
|
x = torch.cat(tensors=(x, self.mask_latents, self.target_image_latents), dim=1)
|
||||||
|
return super().forward(
|
||||||
|
x=x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=condition_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_inpainting_conditions(
|
||||||
|
self,
|
||||||
|
target_image: Image.Image,
|
||||||
|
mask: Image.Image,
|
||||||
|
latents_size: tuple[int, int] = (64, 64),
|
||||||
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
target_image = target_image.convert(mode="RGB")
|
||||||
|
mask = mask.convert(mode="L")
|
||||||
|
|
||||||
|
mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
|
||||||
|
mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
|
||||||
|
self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
|
||||||
|
|
||||||
|
init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
|
||||||
|
masked_init_image = init_image_tensor * (1 - mask_tensor)
|
||||||
|
self.target_image_latents = self.lda.encode(x=masked_init_image)
|
||||||
|
|
||||||
|
return self.mask_latents, self.target_image_latents
|
|
@ -7,6 +7,7 @@ import refiners.fluxion.layers as fl
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
|
from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
|
||||||
|
import refiners.foundationals.latent_diffusion.model as ldm
|
||||||
|
|
||||||
|
|
||||||
class TimestepEncoder(fl.Passthrough):
|
class TimestepEncoder(fl.Passthrough):
|
||||||
|
@ -242,7 +243,7 @@ class ResidualConcatenator(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UNet(fl.Chain):
|
class SD1UNet(fl.Chain, ldm.UNetInterface):
|
||||||
structural_attrs = ["in_channels", "clip_embedding_dim"]
|
structural_attrs = ["in_channels", "clip_embedding_dim"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
|
@ -0,0 +1,8 @@
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SDXLUNet",
|
||||||
|
"DoubleTextEncoder",
|
||||||
|
]
|
|
@ -7,6 +7,7 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextE
|
||||||
from jaxtyping import Float
|
from jaxtyping import Float
|
||||||
|
|
||||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||||
|
from refiners.foundationals.latent_diffusion.model import TextEncoderInterface
|
||||||
|
|
||||||
|
|
||||||
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
||||||
|
@ -59,7 +60,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
|
||||||
return x[:, end_of_text_index[0], :]
|
return x[:, end_of_text_index[0], :]
|
||||||
|
|
||||||
|
|
||||||
class DoubleTextEncoder(fl.Chain):
|
class DoubleTextEncoder(fl.Chain, TextEncoderInterface):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
text_encoder_l: CLIPTextEncoderL | None = None,
|
text_encoder_l: CLIPTextEncoderL | None = None,
|
|
@ -3,7 +3,12 @@ from torch import Tensor, device as Device, dtype as DType
|
||||||
from refiners.fluxion.context import Contexts
|
from refiners.fluxion.context import Contexts
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
|
||||||
from refiners.foundationals.latent_diffusion.unet import ResidualAccumulator, ResidualBlock, ResidualConcatenator
|
from refiners.foundationals.latent_diffusion.model import UNetInterface
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
|
||||||
|
ResidualAccumulator,
|
||||||
|
ResidualBlock,
|
||||||
|
ResidualConcatenator,
|
||||||
|
)
|
||||||
from refiners.adapters.range_adapter import RangeAdapter2d, RangeEncoder, compute_sinusoidal_embedding
|
from refiners.adapters.range_adapter import RangeAdapter2d, RangeEncoder, compute_sinusoidal_embedding
|
||||||
|
|
||||||
|
|
||||||
|
@ -242,7 +247,7 @@ class OutputBlock(fl.Chain):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SDXLUNet(fl.Chain):
|
class SDXLUNet(fl.Chain, UNetInterface):
|
||||||
structural_attrs = ["in_channels"]
|
structural_attrs = ["in_channels"]
|
||||||
|
|
||||||
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
|
@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
|
||||||
from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
|
from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||||
from torchvision.transforms import RandomCrop # type: ignore
|
from torchvision.transforms import RandomCrop # type: ignore
|
||||||
|
@ -103,9 +103,9 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
|
||||||
|
|
||||||
class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
||||||
@cached_property
|
@cached_property
|
||||||
def unet(self) -> UNet:
|
def unet(self) -> SD1UNet:
|
||||||
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
|
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
|
||||||
return UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
|
return SD1UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def text_encoder(self) -> CLIPTextEncoderL:
|
def text_encoder(self) -> CLIPTextEncoderL:
|
||||||
|
@ -171,14 +171,12 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
||||||
for i in range(num_images_per_prompt):
|
for i in range(num_images_per_prompt):
|
||||||
logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
|
logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
|
||||||
x = randn(1, 4, 64, 64, device=self.device)
|
x = randn(1, 4, 64, 64, device=self.device)
|
||||||
clip_text_embedding = sd.compute_text_embedding(text=prompt).to(device=self.device)
|
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt).to(device=self.device)
|
||||||
negative_clip_text_embedding = sd.compute_text_embedding(text="").to(device=self.device)
|
|
||||||
for step in sd.steps:
|
for step in sd.steps:
|
||||||
x = sd(
|
x = sd(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
)
|
)
|
||||||
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
|
canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
|
||||||
images[prompt] = canvas_image
|
images[prompt] = canvas_image
|
||||||
|
|
|
@ -9,8 +9,8 @@ from pathlib import Path
|
||||||
|
|
||||||
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, manual_seed
|
from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, manual_seed
|
||||||
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting
|
from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
from refiners.foundationals.latent_diffusion.controlnet import Controlnet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
|
||||||
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
from refiners.foundationals.latent_diffusion.lora import LoraWeights
|
||||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||||
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
|
from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection
|
||||||
|
@ -196,7 +196,7 @@ def sd15_inpainting(
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
|
||||||
unet = UNet(in_channels=9, clip_embedding_dim=768)
|
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||||
|
@ -214,7 +214,7 @@ def sd15_inpainting_float16(
|
||||||
warn("not running on CPU, skipping")
|
warn("not running on CPU, skipping")
|
||||||
pytest.skip()
|
pytest.skip()
|
||||||
|
|
||||||
unet = UNet(in_channels=9, clip_embedding_dim=768)
|
unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
|
||||||
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)
|
||||||
|
|
||||||
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
|
||||||
|
@ -253,8 +253,7 @@ def test_diffusion_std_random_init(
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
@ -267,7 +266,6 @@ def test_diffusion_std_random_init(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
condition_scale=7.5,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
@ -286,11 +284,9 @@ def test_diffusion_std_random_init_float16(
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
|
||||||
|
|
||||||
assert clip_text_embedding.dtype == torch.float16
|
assert clip_text_embedding.dtype == torch.float16
|
||||||
assert negative_clip_text_embedding.dtype == torch.float16
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
@ -303,7 +299,6 @@ def test_diffusion_std_random_init_float16(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
condition_scale=7.5,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
@ -325,8 +320,7 @@ def test_diffusion_std_init_image(
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
@ -339,7 +333,6 @@ def test_diffusion_std_init_image(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
condition_scale=7.5,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
@ -362,8 +355,7 @@ def test_diffusion_inpainting(
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
|
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
|
||||||
|
@ -377,7 +369,6 @@ def test_diffusion_inpainting(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
condition_scale=7.5,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
@ -401,11 +392,9 @@ def test_diffusion_inpainting_float16(
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
|
||||||
|
|
||||||
assert clip_text_embedding.dtype == torch.float16
|
assert clip_text_embedding.dtype == torch.float16
|
||||||
assert negative_clip_text_embedding.dtype == torch.float16
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
|
sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
|
||||||
|
@ -419,7 +408,6 @@ def test_diffusion_inpainting_float16(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
condition_scale=7.5,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
@ -447,13 +435,12 @@ def test_diffusion_controlnet(
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||||
controlnet = Controlnet(name=cn_name, device=test_device)
|
controlnet = SD1Controlnet(name=cn_name, device=test_device)
|
||||||
controlnet.load_state_dict(controlnet_state_dict)
|
controlnet.load_state_dict(controlnet_state_dict)
|
||||||
controlnet.set_scale(0.5)
|
controlnet.set_scale(0.5)
|
||||||
sd15.unet.insert(0, controlnet)
|
sd15.unet.insert(0, controlnet)
|
||||||
|
@ -470,7 +457,6 @@ def test_diffusion_controlnet(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
condition_scale=7.5,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
@ -498,13 +484,12 @@ def test_diffusion_controlnet_structural_copy(
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||||
controlnet = Controlnet(name=cn_name, device=test_device)
|
controlnet = SD1Controlnet(name=cn_name, device=test_device)
|
||||||
controlnet.load_state_dict(controlnet_state_dict)
|
controlnet.load_state_dict(controlnet_state_dict)
|
||||||
controlnet.set_scale(0.5)
|
controlnet.set_scale(0.5)
|
||||||
sd15.unet.insert(0, controlnet)
|
sd15.unet.insert(0, controlnet)
|
||||||
|
@ -521,7 +506,6 @@ def test_diffusion_controlnet_structural_copy(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
condition_scale=7.5,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
@ -548,13 +532,12 @@ def test_diffusion_controlnet_float16(
|
||||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
|
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
controlnet_state_dict = load_from_safetensors(cn_weights_path)
|
||||||
controlnet = Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
|
controlnet = SD1Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
|
||||||
controlnet.load_state_dict(controlnet_state_dict)
|
controlnet.load_state_dict(controlnet_state_dict)
|
||||||
controlnet.set_scale(0.5)
|
controlnet.set_scale(0.5)
|
||||||
sd15.unet.insert(0, controlnet)
|
sd15.unet.insert(0, controlnet)
|
||||||
|
@ -571,7 +554,6 @@ def test_diffusion_controlnet_float16(
|
||||||
x,
|
x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
negative_clip_text_embedding=negative_clip_text_embedding,
|
|
||||||
condition_scale=7.5,
|
condition_scale=7.5,
|
||||||
)
|
)
|
||||||
predicted_image = sd15.lda.decode_latents(x)
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
@ -597,7 +579,7 @@ def test_diffusion_lora(
|
||||||
prompt = "a cute cat"
|
prompt = "a cute cat"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||||
|
|
||||||
sd15.set_num_inference_steps(n_steps)
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
|
||||||
|
@ -631,7 +613,7 @@ def test_diffusion_refonly(
|
||||||
prompt = "Chicken"
|
prompt = "Chicken"
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||||
|
|
||||||
sai = SelfAttentionInjection(sd15.unet)
|
sai = SelfAttentionInjection(sd15.unet)
|
||||||
sai.inject()
|
sai.inject()
|
||||||
|
@ -673,7 +655,7 @@ def test_diffusion_inpainting_refonly(
|
||||||
prompt = "" # unconditional
|
prompt = "" # unconditional
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
clip_text_embedding = sd15.compute_text_embedding(prompt)
|
clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
|
||||||
|
|
||||||
sai = SelfAttentionInjection(sd15.unet)
|
sai = SelfAttentionInjection(sd15.unet)
|
||||||
sai.inject()
|
sai.inject()
|
||||||
|
|
|
@ -8,7 +8,7 @@ from torch import Tensor
|
||||||
from refiners.fluxion.utils import manual_seed
|
from refiners.fluxion.utils import manual_seed
|
||||||
import refiners.fluxion.layers as fl
|
import refiners.fluxion.layers as fl
|
||||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
|
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
|
||||||
from refiners.foundationals.latent_diffusion.sdxl_text_encoder import DoubleTextEncoder
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||||
|
|
||||||
|
|
||||||
class DiffusersSDXL(Protocol):
|
class DiffusersSDXL(Protocol):
|
||||||
|
|
|
@ -4,7 +4,7 @@ from warnings import warn
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
|
||||||
from refiners.fluxion.utils import compare_models
|
from refiners.fluxion.utils import compare_models
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from refiners.foundationals.latent_diffusion.unet import UNet
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
|
||||||
from refiners.fluxion import manual_seed
|
from refiners.fluxion import manual_seed
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ def test_unet_context_flush():
|
||||||
timestep = torch.randint(0, 999, size=(1, 1))
|
timestep = torch.randint(0, 999, size=(1, 1))
|
||||||
x = torch.randn(1, 4, 32, 32)
|
x = torch.randn(1, 4, 32, 32)
|
||||||
|
|
||||||
unet = UNet(in_channels=4, clip_embedding_dim=768)
|
unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
|
||||||
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
|
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|
Loading…
Reference in a new issue