Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)

Commit 92a21bc21e: refactor latent_diffusion module
Parent commit: 3ee0ccccdc
@@ -249,7 +249,7 @@ lora_weights.patch(sd15, scale=1.0)
 prompt = "a cute cat"

 with torch.no_grad():
-    clip_text_embedding = sd15.compute_text_embedding(prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

 sd15.set_num_inference_steps(30)
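Note on the hunk above: the method is renamed and, per the new implementation later in this commit, it also accepts an optional negative prompt. A minimal hedged sketch of the new call, assuming the same sd15 object as in the README example:

# Hedged sketch; `sd15` is assumed to be a StableDiffusion_1 instance as above.
with torch.no_grad():
    # The new method folds the negative (unconditional) prompt into a single
    # tensor, concatenating both embeddings along the batch dimension.
    clip_text_embedding = sd15.compute_clip_text_embedding(
        text="a cute cat",
        negative_text="lowres, bad anatomy",
    )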
@@ -6,19 +6,19 @@ from refiners.fluxion.utils import (
     convert_state_dict,
     save_to_safetensors,
 )
-from refiners.foundationals.latent_diffusion.controlnet import Controlnet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
 from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
-from refiners.foundationals.latent_diffusion import UNet
+from refiners.foundationals.latent_diffusion import SD1UNet


 @torch.no_grad()
 def convert(controlnet_src: ControlNetModel) -> dict[str, torch.Tensor]:
-    controlnet = Controlnet(name="mycn")
+    controlnet = SD1Controlnet(name="mycn")

     condition = torch.randn(1, 3, 512, 512)
     controlnet.set_controlnet_condition(condition=condition)

-    unet = UNet(in_channels=4, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
     unet.insert(index=0, module=controlnet)
     clip_text_embedding = torch.rand(1, 77, 768)
     unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
@@ -9,7 +9,7 @@ from torch.nn import Parameter as TorchParameter
 import refiners.fluxion.layers as fl

 from refiners.fluxion.utils import save_to_safetensors
-from refiners.foundationals.latent_diffusion.unet import UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget, apply_loras_to_target
 from refiners.adapters.lora import Lora
 from refiners.fluxion.utils import create_state_dict_mapping

@@ -37,7 +37,7 @@ def process(source: str, base_model: str, output_file: str) -> None:
     diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=base_model) # type: ignore
     diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore

-    refiners_model = UNet(in_channels=4, clip_embedding_dim=768)
+    refiners_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
     target = LoraTarget.CrossAttention
     metadata = {"unet_targets": "CrossAttentionBlock2d"}
     rank = diffusers_state_dict[

@@ -5,7 +5,7 @@ from refiners.fluxion.utils import (
 )
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
-from refiners.foundationals.latent_diffusion.unet import UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget
 from refiners.fluxion.layers.module import Module
 import refiners.fluxion.layers as fl

@@ -19,7 +19,7 @@ from transformers.models.clip.modeling_clip import CLIPTextModel # type: ignore


 @torch.no_grad()
-def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: UNet) -> dict[str, str] | None:
+def create_unet_mapping(src_model: UNet2DConditionModel, dst_model: SD1UNet) -> dict[str, str] | None:
     x = torch.randn(1, 4, 32, 32)
     timestep = torch.tensor(data=[0])
     clip_text_embeddings = torch.randn(1, 77, 768)

@@ -79,7 +79,7 @@ def main() -> None:
        match meta_key:
            case "unet_targets":
                src_model = diffusers_sd.unet # type: ignore
-               dst_model = UNet(in_channels=4, clip_embedding_dim=768)
+               dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)
                create_mapping = create_unet_mapping
                key_prefix = "unet."
                lora_prefix = "lora_unet_"
@@ -5,12 +5,12 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 from diffusers import StableDiffusionInpaintPipeline # type: ignore
 from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore

-from refiners.foundationals.latent_diffusion.unet import UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet


 @torch.no_grad()
 def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
-    dst_model = UNet(in_channels=9, clip_embedding_dim=768)
+    dst_model = SD1UNet(in_channels=9, clip_embedding_dim=768)

     x = torch.randn(1, 9, 32, 32)
     timestep = torch.tensor(data=[0])

@@ -5,12 +5,12 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 from diffusers import DiffusionPipeline # type: ignore
 from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore

-from refiners.foundationals.latent_diffusion.unet import UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet


 @torch.no_grad()
 def convert(src_model: UNet2DConditionModel) -> dict[str, torch.Tensor]:
-    dst_model = UNet(in_channels=4, clip_embedding_dim=768)
+    dst_model = SD1UNet(in_channels=4, clip_embedding_dim=768)

     x = torch.randn(1, 4, 32, 32)
     timestep = torch.tensor(data=[0])

@@ -6,7 +6,7 @@ from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict
 from diffusers import DiffusionPipeline # type: ignore
 from diffusers.models.unet_2d_condition import UNet2DConditionModel # type: ignore

-from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet


 @torch.no_grad()
@@ -1,6 +1,7 @@
 from torch import Tensor, arange, device as Device, dtype as DType
 import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
+import refiners.foundationals.latent_diffusion.model as ldm


 class TokenEncoder(fl.Embedding):

@@ -121,7 +122,7 @@ class TransformerLayer(fl.Chain):
         )


-class CLIPTextEncoder(fl.Chain):
+class CLIPTextEncoder(fl.Chain, ldm.TextEncoderInterface):
     structural_attrs = [
         "embedding_dim",
         "max_sequence_length",

@@ -189,10 +190,6 @@ class CLIPTextEncoder(fl.Chain):
         for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
             parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())

-    @property
-    def unconditional_text_embedding(self) -> Tensor:
-        return self("")
-

 class CLIPTextEncoderL(CLIPTextEncoder):
     """
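The removed property above simply ran the encoder on the empty string, so callers can obtain the same tensor directly. A hedged one-line equivalent, assuming a constructed encoder instance named clip_text_encoder:

# `clip_text_encoder` is assumed to be a CLIPTextEncoderL instance.
unconditional_embedding = clip_text_encoder("")  # what the removed property returned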
@@ -1,201 +1,37 @@
-from typing import TypeVar
-from torch import cat, float32, randn, tensor, device as Device, dtype as DType, Size, Tensor
-from PIL import Image
-import numpy as np
-
-from refiners.fluxion.utils import image_to_tensor, interpolate
-from refiners.fluxion.layers.module import Module
 from refiners.foundationals.latent_diffusion.auto_encoder import (
     LatentDiffusionAutoencoder,
 )
 from refiners.foundationals.clip.text_encoder import (
     CLIPTextEncoder,
     CLIPTextEncoderL,
 )
 from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver
-from refiners.foundationals.latent_diffusion.unet import UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
+    StableDiffusion_1,
+    StableDiffusion_1_Inpainting,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
+    SD1UNet,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import (
+    SD1Controlnet,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import (
+    SDXLUNet,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import (
+    DoubleTextEncoder,
+)

-TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
-
 __all__ = [
     "LatentDiffusionModel",
-    "UNet",
     "StableDiffusion_1",
     "StableDiffusion_1_Inpainting",
+    "SD1UNet",
+    "SD1Controlnet",
     "SDXLUNet",
     "DoubleTextEncoder",
     "DPMSolver",
     "Scheduler",
     "CLIPTextEncoder",
     "CLIPTextEncoderL",
     "LatentDiffusionAutoencoder",
 ]
-
-
-class LatentDiffusionModel(Module):
-    def __init__(
-        self,
-        unet: UNet,
-        lda: LatentDiffusionAutoencoder,
-        clip_text_encoder: CLIPTextEncoder,
-        scheduler: Scheduler,
-        device: Device | str = "cpu",
-        dtype: DType = float32,
-    ):
-        super().__init__()
-        self.device: Device = device if isinstance(device, Device) else Device(device)
-        self.dtype = dtype
-        self.unet = unet.to(self.device, dtype=self.dtype)
-        self.lda = lda.to(self.device, dtype=self.dtype)
-        self.clip_text_encoder = clip_text_encoder.to(self.device, dtype=self.dtype)
-        self.scheduler = scheduler.to(self.device, dtype=self.dtype)
-
-    def set_num_inference_steps(self, num_inference_steps: int):
-        initial_diffusion_rate = self.scheduler.initial_diffusion_rate
-        final_diffusion_rate = self.scheduler.final_diffusion_rate
-        device, dtype = self.scheduler.device, self.scheduler.dtype
-        self.scheduler = self.scheduler.__class__(
-            num_inference_steps,
-            initial_diffusion_rate=initial_diffusion_rate,
-            final_diffusion_rate=final_diffusion_rate,
-        ).to(device=device, dtype=dtype)
-
-    def init_latents(
-        self,
-        size: tuple[int, int],
-        init_image: Image.Image | None = None,
-        first_step: int = 0,
-        noise: Tensor | None = None,
-    ) -> Tensor:
-        if noise is None:
-            height, width = size
-            noise = randn(1, 4, height // 8, width // 8, device=self.device)
-        assert list(noise.shape[2:]) == [
-            size[0] // 8,
-            size[1] // 8,
-        ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
-        if init_image is None:
-            return noise
-        encoded_image = self.lda.encode_image(init_image.resize(size))
-        return self.scheduler.add_noise(encoded_image, noise, self.steps[first_step])
-
-    @property
-    def steps(self) -> list[int]:
-        return self.scheduler.steps
-
-    @property
-    def timestep_embeddings(self) -> Tensor:
-        return self.timestep_encoder(self.scheduler.timesteps)
-
-    @property
-    def unconditional_clip_text_embeddings(self) -> Tensor:
-        return self.clip_text_encoder.unconditional_text_embedding
-
-    def compute_text_embedding(self, text: str) -> Tensor:
-        return self.clip_text_encoder(text)
-
-    def forward(
-        self,
-        x: Tensor,
-        step: int,
-        clip_text_embedding: Tensor,
-        negative_clip_text_embedding: Tensor | None = None,
-        condition_scale: float = 7.5,
-    ) -> Tensor:
-        timestep = self.scheduler.timesteps[step].unsqueeze(0)
-        self.unet.set_timestep(timestep)
-
-        negative_clip_text_embedding = (
-            self.clip_text_encoder.unconditional_text_embedding
-            if negative_clip_text_embedding is None
-            else negative_clip_text_embedding
-        )
-
-        clip_text_embeddings = cat((negative_clip_text_embedding, clip_text_embedding))
-
-        self.unet.set_clip_text_embedding(clip_text_embeddings)
-        latents = cat((x, x)) # for classifier-free guidance
-        unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
-
-        # classifier-free guidance
-        noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
-        x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
-        return self.scheduler(x, noise=noise, step=step)
-
-    def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
-        return self.__class__(
-            unet=self.unet.structural_copy(),
-            lda=self.lda.structural_copy(),
-            clip_text_encoder=self.clip_text_encoder.structural_copy(),
-            scheduler=self.scheduler,
-            device=self.device,
-            dtype=self.dtype,
-        )
-
-
-class StableDiffusion_1(LatentDiffusionModel):
-    def __init__(
-        self,
-        unet: UNet | None = None,
-        lda: LatentDiffusionAutoencoder | None = None,
-        clip_text_encoder: CLIPTextEncoderL | None = None,
-        scheduler: Scheduler | None = None,
-        device: Device | str = "cpu",
-        dtype: DType = float32,
-    ):
-        unet = unet or UNet(in_channels=4, clip_embedding_dim=768)
-        lda = lda or LatentDiffusionAutoencoder()
-        clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
-        scheduler = scheduler or DPMSolver(num_inference_steps=30)
-
-        super().__init__(
-            unet,
-            lda,
-            clip_text_encoder=clip_text_encoder,
-            scheduler=scheduler,
-            device=device,
-            dtype=dtype,
-        )
-
-
-class StableDiffusion_1_Inpainting(StableDiffusion_1):
-    def __init__(
-        self,
-        unet: UNet | None = None,
-        lda: LatentDiffusionAutoencoder | None = None,
-        clip_text_encoder: CLIPTextEncoderL | None = None,
-        scheduler: Scheduler | None = None,
-        device: Device | str = "cpu",
-        dtype: DType = float32,
-    ):
-        self.mask_latents: Tensor | None = None
-        self.target_image_latents: Tensor | None = None
-        super().__init__(unet, lda, clip_text_encoder, scheduler, device, dtype)
-
-    def forward(
-        self,
-        x: Tensor,
-        step: int,
-        clip_text_embedding: Tensor,
-        negative_clip_text_embedding: Tensor | None = None,
-        condition_scale: float = 7.5,
-    ):
-        assert self.mask_latents is not None
-        assert self.target_image_latents is not None
-        x = cat((x, self.mask_latents, self.target_image_latents), dim=1)
-        return super().forward(x, step, clip_text_embedding, negative_clip_text_embedding, condition_scale)
-
-    def set_inpainting_conditions(
-        self,
-        target_image: Image.Image,
-        mask: Image.Image,
-        latents_size: tuple[int, int] = (64, 64),
-    ) -> tuple[Tensor, Tensor]:
-        target_image = target_image.convert("RGB")
-        mask = mask.convert("L")
-
-        mask_tensor = tensor(np.array(mask).astype(np.float32) / 255.0).to(self.device)
-        mask_tensor = (mask_tensor > 0.5).unsqueeze(0).unsqueeze(0).to(dtype=self.dtype)
-        self.mask_latents = interpolate(mask_tensor, Size(latents_size))
-
-        init_image_tensor = image_to_tensor(target_image, device=self.device, dtype=self.dtype) * 2 - 1
-        masked_init_image = init_image_tensor * (1 - mask_tensor)
-        self.target_image_latents = self.lda.encode(masked_init_image)
-
-        return self.mask_latents, self.target_image_latents
@@ -9,7 +9,7 @@ from refiners.adapters.lora import LoraAdapter, load_lora_weights
 from refiners.foundationals.clip.text_encoder import FeedForward, TransformerLayer
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
 from refiners.foundationals.latent_diffusion import StableDiffusion_1
-from refiners.foundationals.latent_diffusion.controlnet import Controlnet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
 import refiners.fluxion.layers as fl
 from refiners.fluxion.utils import load_from_safetensors, load_metadata_from_safetensors

@@ -74,7 +74,7 @@ class LoraWeights:
        match meta_key:
            case "unet_targets":
                # TODO: support this transparently
-               if any([isinstance(module, Controlnet) for module in sd.unet]):
+               if any([isinstance(module, SD1Controlnet) for module in sd.unet]):
                    raise NotImplementedError("Cannot patch a UNet which already contains a Controlnet adapter")
                model = sd.unet
                key_prefix = "unet."
src/refiners/foundationals/latent_diffusion/model.py (new file, 117 lines)
@@ -0,0 +1,117 @@
+from abc import ABC, abstractmethod
+from typing import Protocol, TypeVar
+from torch import Tensor, device as Device, dtype as DType
+from PIL import Image
+import torch
+import refiners.fluxion.layers as fl
+from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+
+
+T = TypeVar("T", bound="fl.Module")
+
+
+class UNetInterface(Protocol):
+    def set_timestep(self, timestep: Tensor) -> None:
+        ...
+
+    def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
+        ...
+
+    def __call__(self, x: Tensor) -> Tensor:
+        ...
+
+
+class TextEncoderInterface(Protocol):
+    def __call__(self, text: str) -> Tensor | tuple[Tensor, Tensor]:
+        ...
+
+
+TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")
+
+
+class LatentDiffusionModel(fl.Module, ABC):
+    def __init__(
+        self,
+        unet: UNetInterface,
+        lda: LatentDiffusionAutoencoder,
+        clip_text_encoder: TextEncoderInterface,
+        scheduler: Scheduler,
+        device: Device | str = "cpu",
+        dtype: DType = torch.float32,
+    ) -> None:
+        super().__init__()
+        self.device: Device = device if isinstance(device, Device) else Device(device=device)
+        self.dtype = dtype
+        assert isinstance(unet, fl.Module)
+        self.unet = unet.to(device=self.device, dtype=self.dtype)
+        self.lda = lda.to(device=self.device, dtype=self.dtype)
+        assert isinstance(clip_text_encoder, fl.Module)
+        self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
+        self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
+
+    def set_num_inference_steps(self, num_inference_steps: int) -> None:
+        initial_diffusion_rate = self.scheduler.initial_diffusion_rate
+        final_diffusion_rate = self.scheduler.final_diffusion_rate
+        device, dtype = self.scheduler.device, self.scheduler.dtype
+        self.scheduler = self.scheduler.__class__(
+            num_inference_steps,
+            initial_diffusion_rate=initial_diffusion_rate,
+            final_diffusion_rate=final_diffusion_rate,
+        ).to(device=device, dtype=dtype)
+
+    def init_latents(
+        self,
+        size: tuple[int, int],
+        init_image: Image.Image | None = None,
+        first_step: int = 0,
+        noise: Tensor | None = None,
+    ) -> Tensor:
+        if noise is None:
+            height, width = size
+            noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
+        assert list(noise.shape[2:]) == [
+            size[0] // 8,
+            size[1] // 8,
+        ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
+        if init_image is None:
+            return noise
+        encoded_image = self.lda.encode_image(image=init_image.resize(size=size))
+        return self.scheduler.add_noise(x=encoded_image, noise=noise, step=self.steps[first_step])
+
+    @property
+    def steps(self) -> list[int]:
+        return self.scheduler.steps
+
+    @abstractmethod
+    def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *args: Tensor) -> None:
+        ...
+
+    def forward(
+        self,
+        x: Tensor,
+        step: int,
+        clip_text_embedding: Tensor,
+        *args: Tensor,
+        condition_scale: float = 7.5,
+    ) -> Tensor:
+        timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
+        self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, *args)
+
+        latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
+        unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
+
+        # classifier-free guidance
+        noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
+        x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
+        return self.scheduler(x, noise=noise, step=step)
+
+    def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
+        return self.__class__(
+            unet=self.unet.structural_copy(),
+            lda=self.lda.structural_copy(),
+            clip_text_encoder=self.clip_text_encoder.structural_copy(),
+            scheduler=self.scheduler,
+            device=self.device,
+            dtype=self.dtype,
+        )
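The new model.py relies on structural typing: anything exposing set_timestep, set_clip_text_embedding and __call__ satisfies UNetInterface, without inheriting from a concrete UNet class. A minimal hedged sketch of that idea; DummyDenoiser and its attribute names are hypothetical and exist only to illustrate the protocol:

from torch import Tensor

from refiners.foundationals.latent_diffusion.model import UNetInterface


class DummyDenoiser:
    """Hypothetical stand-in, only to show structural typing against the protocol."""

    def set_timestep(self, timestep: Tensor) -> None:
        self.timestep = timestep

    def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
        self.clip_text_embedding = clip_text_embedding

    def __call__(self, x: Tensor) -> Tensor:
        return x  # identity "denoiser", just enough to satisfy the signature


denoiser: UNetInterface = DummyDenoiser()  # accepted by a type checker; no subclassing required

Note that LatentDiffusionModel.__init__ still asserts the object is an fl.Module before moving it to the target device, so a real backbone would also derive from fl.Module, as SD1UNet and SDXLUNet do below.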
@@ -10,7 +10,7 @@ from refiners.fluxion.layers import (
     Parallel,
 )
 from refiners.adapters.adapter import Adapter
-from refiners.foundationals.latent_diffusion.unet import UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
 from torch import Tensor

@@ -58,7 +58,7 @@ class ReferenceOnlyControlAdapter(Chain, Adapter[SelfAttention]):
 class SelfAttentionInjection(Passthrough):
     # TODO: Does not support batching yet. Assumes concatenated inputs for classifier-free guidance

-    def __init__(self, unet: UNet, style_cfg: float = 0.5) -> None:
+    def __init__(self, unet: SD1UNet, style_cfg: float = 0.5) -> None:
         # the style_cfg is the weight of the guide in unconditionned diffusion.
         # This value is recommended to be 0.5 on the sdwebui repo.
         self.style_cfg = style_cfg
@@ -0,0 +1,13 @@
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
+    StableDiffusion_1,
+    StableDiffusion_1_Inpainting,
+)
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
+
+__all__ = [
+    "StableDiffusion_1",
+    "StableDiffusion_1_Inpainting",
+    "SD1UNet",
+    "SD1Controlnet",
+]
@@ -1,6 +1,11 @@
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers import Chain, Conv2d, SiLU, Lambda, Passthrough, UseContext, Sum, Identity
-from refiners.foundationals.latent_diffusion.unet import DownBlocks, MiddleBlock, ResidualBlock, TimestepEncoder
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
+    DownBlocks,
+    MiddleBlock,
+    ResidualBlock,
+    TimestepEncoder,
+)
 from refiners.adapters.range_adapter import RangeAdapter2d
 from typing import cast, Iterable
 from torch import Tensor, device as Device, dtype as DType

@@ -64,7 +69,7 @@ class ConditionEncoder(Chain):
         )


-class Controlnet(Passthrough):
+class SD1Controlnet(Passthrough):
     structural_attrs = ["name", "scale"]

     def __init__(self, name: str, device: Device | str | None = None, dtype: DType | None = None) -> None:
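Apart from the rename, the controlnet is still wired into the UNet the same way. A hedged sketch based on the conversion script and e2e tests touched by this commit; the adapter name and the weights path are illustrative placeholders:

import torch

from refiners.fluxion.utils import load_from_safetensors
from refiners.foundationals.latent_diffusion import SD1Controlnet, SD1UNet

# "mycn" and the safetensors path are placeholders, not real assets from the repo.
controlnet = SD1Controlnet(name="mycn")
controlnet.load_state_dict(load_from_safetensors("path/to/controlnet.safetensors"))
controlnet.set_scale(0.5)

unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
unet.insert(index=0, module=controlnet)  # the controlnet runs as a Passthrough ahead of the UNet blocks

condition = torch.randn(1, 3, 512, 512)  # image condition, as in the conversion script above
controlnet.set_controlnet_condition(condition=condition)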
@@ -0,0 +1,105 @@
+import torch
+from refiners.fluxion.utils import image_to_tensor, interpolate
+from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
+from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
+from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
+from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+from PIL import Image
+import numpy as np
+from torch import device as Device, dtype as DType, Tensor
+
+
+class StableDiffusion_1(LatentDiffusionModel):
+    unet: SD1UNet
+    clip_text_encoder: CLIPTextEncoderL
+
+    def __init__(
+        self,
+        unet: SD1UNet | None = None,
+        lda: LatentDiffusionAutoencoder | None = None,
+        clip_text_encoder: CLIPTextEncoderL | None = None,
+        scheduler: Scheduler | None = None,
+        device: Device | str = "cpu",
+        dtype: DType = torch.float32,
+    ) -> None:
+        unet = unet or SD1UNet(in_channels=4, clip_embedding_dim=768)
+        lda = lda or LatentDiffusionAutoencoder()
+        clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
+        scheduler = scheduler or DPMSolver(num_inference_steps=30)
+
+        super().__init__(
+            unet=unet,
+            lda=lda,
+            clip_text_encoder=clip_text_encoder,
+            scheduler=scheduler,
+            device=device,
+            dtype=dtype,
+        )
+
+    def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
+        conditional_embedding = self.clip_text_encoder(text)
+        if text == negative_text:
+            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
+
+        negative_embedding = self.clip_text_encoder(negative_text or "")
+        return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
+
+    def set_unet_context(self, timestep: Tensor, clip_text_embedding: Tensor, *_: Tensor) -> None:
+        self.unet.set_timestep(timestep=timestep)
+        self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
+
+
+class StableDiffusion_1_Inpainting(StableDiffusion_1):
+    def __init__(
+        self,
+        unet: SD1UNet | None = None,
+        lda: LatentDiffusionAutoencoder | None = None,
+        clip_text_encoder: CLIPTextEncoderL | None = None,
+        scheduler: Scheduler | None = None,
+        device: Device | str = "cpu",
+        dtype: DType = torch.float32,
+    ) -> None:
+        self.mask_latents: Tensor | None = None
+        self.target_image_latents: Tensor | None = None
+        super().__init__(
+            unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+        step: int,
+        clip_text_embedding: Tensor,
+        *_: Tensor,
+        condition_scale: float = 7.5,
+    ) -> Tensor:
+        assert self.mask_latents is not None
+        assert self.target_image_latents is not None
+        x = torch.cat(tensors=(x, self.mask_latents, self.target_image_latents), dim=1)
+        return super().forward(
+            x=x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=condition_scale,
+        )
+
+    def set_inpainting_conditions(
+        self,
+        target_image: Image.Image,
+        mask: Image.Image,
+        latents_size: tuple[int, int] = (64, 64),
+    ) -> tuple[Tensor, Tensor]:
+        target_image = target_image.convert(mode="RGB")
+        mask = mask.convert(mode="L")
+
+        mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
+        mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
+        self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
+
+        init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
+        masked_init_image = init_image_tensor * (1 - mask_tensor)
+        self.target_image_latents = self.lda.encode(x=masked_init_image)
+
+        return self.mask_latents, self.target_image_latents
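Taken together with the README hunk at the top, a hedged end-to-end sketch of the new SD1 entry points; loading pretrained weights is omitted, so this only illustrates the API shape exercised by the repository's tests:

import torch

from refiners.foundationals.latent_diffusion import StableDiffusion_1

sd15 = StableDiffusion_1(device="cpu")  # builds SD1UNet, CLIPTextEncoderL, LDA and DPMSolver by default

prompt = "a cute cat"
negative_prompt = "lowres, bad anatomy"

with torch.no_grad():
    # one tensor holding the negative and positive embeddings, concatenated along dim 0
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_num_inference_steps(30)
    x = torch.randn(1, 4, 64, 64, device=sd15.device)
    for step in sd15.steps:
        x = sd15(x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=7.5)
    predicted_image = sd15.lda.decode_latents(x)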
@@ -7,6 +7,7 @@ import refiners.fluxion.layers as fl

 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
 from refiners.adapters.range_adapter import RangeEncoder, RangeAdapter2d
+import refiners.foundationals.latent_diffusion.model as ldm


 class TimestepEncoder(fl.Passthrough):

@@ -242,7 +243,7 @@ class ResidualConcatenator(fl.Chain):
         )


-class UNet(fl.Chain):
+class SD1UNet(fl.Chain, ldm.UNetInterface):
     structural_attrs = ["in_channels", "clip_embedding_dim"]

     def __init__(
@@ -0,0 +1,8 @@
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
+
+
+__all__ = [
+    "SDXLUNet",
+    "DoubleTextEncoder",
+]
@@ -7,6 +7,7 @@ from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextE
 from jaxtyping import Float

 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
+from refiners.foundationals.latent_diffusion.model import TextEncoderInterface


 class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):

@@ -59,7 +60,7 @@ class TextEncoderWithPooling(fl.Chain, Adapter[CLIPTextEncoderG]):
         return x[:, end_of_text_index[0], :]


-class DoubleTextEncoder(fl.Chain):
+class DoubleTextEncoder(fl.Chain, TextEncoderInterface):
     def __init__(
         self,
         text_encoder_l: CLIPTextEncoderL | None = None,
@@ -3,7 +3,12 @@ from torch import Tensor, device as Device, dtype as DType
 from refiners.fluxion.context import Contexts
 import refiners.fluxion.layers as fl
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
-from refiners.foundationals.latent_diffusion.unet import ResidualAccumulator, ResidualBlock, ResidualConcatenator
+from refiners.foundationals.latent_diffusion.model import UNetInterface
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import (
+    ResidualAccumulator,
+    ResidualBlock,
+    ResidualConcatenator,
+)
 from refiners.adapters.range_adapter import RangeAdapter2d, RangeEncoder, compute_sinusoidal_embedding

@@ -242,7 +247,7 @@ class OutputBlock(fl.Chain):
         )


-class SDXLUNet(fl.Chain):
+class SDXLUNet(fl.Chain, UNetInterface):
     structural_attrs = ["in_channels"]

     def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
@@ -5,7 +5,7 @@ from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
 from torch import device as Device, Tensor, randn, dtype as DType, Generator, cat
 from loguru import logger
 from torch.utils.data import Dataset
-from refiners.foundationals.latent_diffusion.unet import UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 from torchvision.transforms import RandomCrop # type: ignore

@@ -103,9 +103,9 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):

 class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
     @cached_property
-    def unet(self) -> UNet:
+    def unet(self) -> SD1UNet:
         assert self.config.models["unet"] is not None, "The config must contain a unet entry."
-        return UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)
+        return SD1UNet(in_channels=4, clip_embedding_dim=768, device=self.device).to(device=self.device)

     @cached_property
     def text_encoder(self) -> CLIPTextEncoderL:

@@ -171,14 +171,12 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
        for i in range(num_images_per_prompt):
            logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
            x = randn(1, 4, 64, 64, device=self.device)
-           clip_text_embedding = sd.compute_text_embedding(text=prompt).to(device=self.device)
-           negative_clip_text_embedding = sd.compute_text_embedding(text="").to(device=self.device)
+           clip_text_embedding = sd.compute_clip_text_embedding(text=prompt).to(device=self.device)
            for step in sd.steps:
                x = sd(
                    x,
                    step=step,
                    clip_text_embedding=clip_text_embedding,
-                   negative_clip_text_embedding=negative_clip_text_embedding,
                )
            canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
        images[prompt] = canvas_image
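The sampling loop above now relies on the embedding itself carrying the unconditional half (negative_text defaults to the empty string), so the separate negative_clip_text_embedding argument disappears. A hedged summary of the calling convention, with the old signatures shown only as comments:

# Before this commit (two tensors, old signatures assumed from the removed lines):
#   clip = sd.compute_text_embedding(text=prompt)
#   negative = sd.compute_text_embedding(text="")
#   x = sd(x, step=step, clip_text_embedding=clip, negative_clip_text_embedding=negative)
#
# After this commit (one concatenated tensor):
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt)  # negative_text="" by default
x = sd(x, step=step, clip_text_embedding=clip_text_embedding)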
@@ -9,8 +9,8 @@ from pathlib import Path

 from refiners.fluxion.utils import load_from_safetensors, image_to_tensor, manual_seed
 from refiners.foundationals.latent_diffusion import StableDiffusion_1, StableDiffusion_1_Inpainting
-from refiners.foundationals.latent_diffusion.unet import UNet
-from refiners.foundationals.latent_diffusion.controlnet import Controlnet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1Controlnet
 from refiners.foundationals.latent_diffusion.lora import LoraWeights
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
 from refiners.foundationals.latent_diffusion.self_attention_injection import SelfAttentionInjection

@@ -196,7 +196,7 @@ def sd15_inpainting(
         warn("not running on CPU, skipping")
         pytest.skip()

-    unet = UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
     sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device)

     sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))

@@ -214,7 +214,7 @@ def sd15_inpainting_float16(
         warn("not running on CPU, skipping")
         pytest.skip()

-    unet = UNet(in_channels=9, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=9, clip_embedding_dim=768)
     sd15 = StableDiffusion_1_Inpainting(unet=unet, device=test_device, dtype=torch.float16)

     sd15.clip_text_encoder.load_state_dict(load_from_safetensors(text_encoder_weights))
@@ -253,8 +253,7 @@ def test_diffusion_std_random_init(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
-        negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -267,7 +266,6 @@ def test_diffusion_std_random_init(
             x,
             step=step,
             clip_text_embedding=clip_text_embedding,
-            negative_clip_text_embedding=negative_clip_text_embedding,
             condition_scale=7.5,
         )
     predicted_image = sd15.lda.decode_latents(x)

@@ -286,11 +284,9 @@ def test_diffusion_std_random_init_float16(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
-        negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     assert clip_text_embedding.dtype == torch.float16
-    assert negative_clip_text_embedding.dtype == torch.float16

     sd15.set_num_inference_steps(n_steps)

@@ -303,7 +299,6 @@ def test_diffusion_std_random_init_float16(
             x,
             step=step,
             clip_text_embedding=clip_text_embedding,
-            negative_clip_text_embedding=negative_clip_text_embedding,
             condition_scale=7.5,
         )
     predicted_image = sd15.lda.decode_latents(x)
@@ -325,8 +320,7 @@ def test_diffusion_std_init_image(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
-        negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -339,7 +333,6 @@ def test_diffusion_std_init_image(
             x,
             step=step,
             clip_text_embedding=clip_text_embedding,
-            negative_clip_text_embedding=negative_clip_text_embedding,
             condition_scale=7.5,
         )
     predicted_image = sd15.lda.decode_latents(x)

@@ -362,8 +355,7 @@ def test_diffusion_inpainting(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
-        negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)
     sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)

@@ -377,7 +369,6 @@ def test_diffusion_inpainting(
             x,
             step=step,
             clip_text_embedding=clip_text_embedding,
-            negative_clip_text_embedding=negative_clip_text_embedding,
             condition_scale=7.5,
         )
     predicted_image = sd15.lda.decode_latents(x)
@@ -401,11 +392,9 @@ def test_diffusion_inpainting_float16(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
-        negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     assert clip_text_embedding.dtype == torch.float16
-    assert negative_clip_text_embedding.dtype == torch.float16

     sd15.set_num_inference_steps(n_steps)
     sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)

@@ -419,7 +408,6 @@ def test_diffusion_inpainting_float16(
             x,
             step=step,
             clip_text_embedding=clip_text_embedding,
-            negative_clip_text_embedding=negative_clip_text_embedding,
             condition_scale=7.5,
         )
     predicted_image = sd15.lda.decode_latents(x)

@@ -447,13 +435,12 @@ def test_diffusion_controlnet(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
-        negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

     controlnet_state_dict = load_from_safetensors(cn_weights_path)
-    controlnet = Controlnet(name=cn_name, device=test_device)
+    controlnet = SD1Controlnet(name=cn_name, device=test_device)
     controlnet.load_state_dict(controlnet_state_dict)
     controlnet.set_scale(0.5)
     sd15.unet.insert(0, controlnet)

@@ -470,7 +457,6 @@ def test_diffusion_controlnet(
             x,
             step=step,
             clip_text_embedding=clip_text_embedding,
-            negative_clip_text_embedding=negative_clip_text_embedding,
             condition_scale=7.5,
         )
     predicted_image = sd15.lda.decode_latents(x)
@@ -498,13 +484,12 @@ def test_diffusion_controlnet_structural_copy(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
-        negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

     controlnet_state_dict = load_from_safetensors(cn_weights_path)
-    controlnet = Controlnet(name=cn_name, device=test_device)
+    controlnet = SD1Controlnet(name=cn_name, device=test_device)
     controlnet.load_state_dict(controlnet_state_dict)
     controlnet.set_scale(0.5)
     sd15.unet.insert(0, controlnet)

@@ -521,7 +506,6 @@ def test_diffusion_controlnet_structural_copy(
             x,
             step=step,
             clip_text_embedding=clip_text_embedding,
-            negative_clip_text_embedding=negative_clip_text_embedding,
             condition_scale=7.5,
         )
     predicted_image = sd15.lda.decode_latents(x)

@@ -548,13 +532,12 @@ def test_diffusion_controlnet_float16(
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
-        negative_clip_text_embedding = sd15.compute_text_embedding(negative_prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

     controlnet_state_dict = load_from_safetensors(cn_weights_path)
-    controlnet = Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
+    controlnet = SD1Controlnet(name=cn_name, device=test_device, dtype=torch.float16)
     controlnet.load_state_dict(controlnet_state_dict)
     controlnet.set_scale(0.5)
     sd15.unet.insert(0, controlnet)

@@ -571,7 +554,6 @@ def test_diffusion_controlnet_float16(
             x,
             step=step,
             clip_text_embedding=clip_text_embedding,
-            negative_clip_text_embedding=negative_clip_text_embedding,
             condition_scale=7.5,
         )
     predicted_image = sd15.lda.decode_latents(x)
@@ -597,7 +579,7 @@ def test_diffusion_lora(
     prompt = "a cute cat"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -631,7 +613,7 @@ def test_diffusion_refonly(
     prompt = "Chicken"

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sai = SelfAttentionInjection(sd15.unet)
     sai.inject()

@@ -673,7 +655,7 @@ def test_diffusion_inpainting_refonly(
     prompt = "" # unconditional

     with torch.no_grad():
-        clip_text_embedding = sd15.compute_text_embedding(prompt)
+        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sai = SelfAttentionInjection(sd15.unet)
     sai.inject()
@@ -8,7 +8,7 @@ from torch import Tensor
 from refiners.fluxion.utils import manual_seed
 import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderG, CLIPTextEncoderL
-from refiners.foundationals.latent_diffusion.sdxl_text_encoder import DoubleTextEncoder
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder


 class DiffusersSDXL(Protocol):

@@ -4,7 +4,7 @@ from warnings import warn
 import pytest
 import torch

-from refiners.foundationals.latent_diffusion.sdxl_unet import SDXLUNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
 from refiners.fluxion.utils import compare_models


@@ -1,4 +1,4 @@
-from refiners.foundationals.latent_diffusion.unet import UNet
+from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
 from refiners.fluxion import manual_seed
 import torch


@@ -9,7 +9,7 @@ def test_unet_context_flush():
     timestep = torch.randint(0, 999, size=(1, 1))
     x = torch.randn(1, 4, 32, 32)

-    unet = UNet(in_channels=4, clip_embedding_dim=768)
+    unet = SD1UNet(in_channels=4, clip_embedding_dim=768)
     unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s

     with torch.no_grad():