create SD1.5 MultiUpscaler pipeline

Co-authored-by: limiteinductive <benjamin@lagon.tech>
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
Laurent 2024-07-11 13:04:59 +00:00 committed by Laureηt
parent 66cd0d57a1
commit f3b5c8d3e1

@@ -0,0 +1,372 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Sequence
import torch
from PIL import Image
from torch import Tensor
from typing_extensions import TypeVar
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1,
)
Name = str
@dataclass(kw_only=True)
class UpscalerCheckpoints:
"""
Checkpoints for the multi upscaler.
Attributes:
unet: The path to the Stable Diffusion 1 UNet checkpoint.
clip_text_encoder: The path to the CLIP text encoder checkpoint.
lda: The path to the LDA checkpoint.
controlnet_tile: The path to the controlnet tile checkpoint.
negative_embedding: The path to the negative embedding checkpoint.
negative_embedding_key: The key for the negative embedding. If the negative embedding is a dictionary, this
key is used to access the negative embedding. You can use a dot-separated path to access nested dictionaries.
        loras: A dictionary of LoRAs to load. The key is the name of the LoRA and the value is the path to the LoRA
            checkpoint.
"""
unet: Path
clip_text_encoder: Path
lda: Path
controlnet_tile: Path
negative_embedding: Path | None = None
negative_embedding_key: str | None = None
loras: dict[Name, Path] | None = None
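
# Construction sketch (the checkpoint paths below are hypothetical placeholders, not
# files shipped with this repository; see load_negative_embedding for what the
# "string_to_param.*" key refers to):
#
#     checkpoints = UpscalerCheckpoints(
#         unet=Path("checkpoints/unet.safetensors"),
#         clip_text_encoder=Path("checkpoints/text_encoder.safetensors"),
#         lda=Path("checkpoints/autoencoder.safetensors"),
#         controlnet_tile=Path("checkpoints/controlnet_tile.safetensors"),
#         negative_embedding=Path("checkpoints/negative_embedding.pt"),
#         negative_embedding_key="string_to_param.*",
#         loras={"more_details": Path("checkpoints/more_details.safetensors")},
#     )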
@dataclass(kw_only=True)
class UpscalerTarget(DiffusionTarget):
clip_text_embedding: Tensor
controlnet_condition: Tensor
condition_scale: float = 7.0
T = TypeVar("T", bound=DiffusionTarget)
class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
def __init__(self, checkpoints: UpscalerCheckpoints, device: torch.device, dtype: torch.dtype) -> None:
self.device = device
self.dtype = dtype
self.sd = self.load_stable_diffusion(checkpoints)
self.manager = self.load_loras(checkpoints.loras)
self.controlnet = self.load_controlnet(checkpoints)
self.negative_embedding_token = self.load_negative_embedding(
checkpoints.negative_embedding, checkpoints.negative_embedding_key
)
@abstractmethod
def compute_targets(
self,
image: Image.Image,
latent_size: Size,
tile_size: Size,
num_inference_steps: int,
first_step: int,
condition_scale: float,
clip_text_embedding: torch.Tensor,
) -> Sequence[T]: ...
@abstractmethod
def diffuse_target(self, x: Tensor, step: int, target: T) -> Tensor: ...
def load_stable_diffusion(self, checkpoints: UpscalerCheckpoints) -> StableDiffusion_1:
sd = StableDiffusion_1(device=self.device, dtype=self.dtype)
sd.unet.load_from_safetensors(checkpoints.unet)
sd.clip_text_encoder.load_from_safetensors(checkpoints.clip_text_encoder)
sd.lda.load_from_safetensors(checkpoints.lda)
return sd
def load_controlnet(self, checkpoints: UpscalerCheckpoints) -> SD1ControlnetAdapter:
return SD1ControlnetAdapter(
target=self.sd.unet,
name="tile",
weights=load_from_safetensors(checkpoints.controlnet_tile),
).inject()
def load_loras(self, loras: dict[Name, Path] | None) -> SDLoraManager | None:
if loras is None:
return
manager = SDLoraManager(self.sd)
for name, path in loras.items():
manager.add_loras(name, tensors=load_from_safetensors(path))
return manager
def load_negative_embedding(self, path: Path | None, key: str | None) -> str:
if path is None:
return ""
embeddings: Tensor | dict[str, Any] = torch.load(path, weights_only=True) # type: ignore
if isinstance(embeddings, dict):
assert key is not None, "Key must be provided to access the negative embedding."
key_sequence = key.split(".")
for key in key_sequence:
assert (
key in embeddings
), f"Key {key} not found in the negative embedding dictionary. Available keys: {list(embeddings.keys())}"
embeddings = embeddings[key]
assert isinstance(
embeddings, torch.Tensor
), f"The negative embedding must be a tensor, found {type(embeddings)}."
assert embeddings.ndim == 2, f"The negative embedding must be a 2D tensor, found {embeddings.ndim}D tensor."
extender = ConceptExtender(self.sd.clip_text_encoder)
negative_embedding_token = ", "
for i, embedding in enumerate(embeddings):
extender.add_concept(token=f"<{i}>", embedding=embedding)
negative_embedding_token += f"<{i}> "
extender.inject()
return negative_embedding_token
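
    # For example, A1111-style textual-inversion .pt files typically nest the tensor as
    # {"string_to_param": {"*": tensor}}, so negative_embedding_key="string_to_param.*"
    # walks down to it (a convention of those checkpoints, not something enforced here).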
@no_grad()
def compute_clip_text_embedding(self, prompt: str, negative_prompt: str, offload_to_cpu: bool = True) -> Tensor:
"""Compute the CLIP text embedding for the prompt and negative prompt.
Args:
prompt: The prompt to use for the upscaling.
negative_prompt: The negative prompt to use for the upscaling.
            offload_to_cpu: Whether to offload the text encoder to the CPU after computing the embedding.
"""
if self.negative_embedding_token:
negative_prompt += self.negative_embedding_token
self.sd.clip_text_encoder.to(device=self.device, dtype=self.dtype)
clip_text_embedding = self.sd.compute_clip_text_embedding(
text=prompt,
negative_text=negative_prompt,
)
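        # the returned tensor carries both the negative and the positive embeddings,
        # consumed later by classifier-free guidance (see condition_scale)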
if offload_to_cpu:
self.sd.clip_text_encoder.to(torch.device("cpu"))
return clip_text_embedding
def diffuse_upscaler_target(self, x: Tensor, step: int, target: UpscalerTarget) -> Tensor:
self.sd.solver = target.solver
self.controlnet.set_controlnet_condition(target.controlnet_condition)
return self.sd(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
condition_scale=target.condition_scale,
)
@staticmethod
def resize_modulo_8(image: Image.Image, size: int = 768, on_short: bool = True) -> Image.Image:
"""
        Resize an image, preserving its aspect ratio and ensuring both of its dimensions are multiples of 8.
The `on_short` parameter determines whether the resizing is based on the shortest side.
"""
        assert size % 8 == 0, "Size must be a multiple of 8 because of the 8x latent compression factor."
side_size = min(image.size) if on_short else max(image.size)
scale = size / (side_size * 8)
new_size = (int(image.width * scale) * 8, int(image.height * scale) * 8)
return image.resize(new_size, resample=Image.Resampling.LANCZOS) # type: ignore
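
    # Worked example: a 1000x600 image with size=768 and on_short=True gives
    # scale = 768 / (600 * 8) = 0.16, so new_size = (int(1000 * 0.16) * 8,
    # int(600 * 0.16) * 8) = (1280, 768) and the short side lands exactly on 768.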
@no_grad()
def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
"""
Pre-upscale an image before the actual upscaling process.
        You can override this method to implement custom pre-upscaling logic, such as running an ESRGAN model as in
        the original implementation.
"""
return image.resize(
(int(image.width * upscale_factor), int(image.height * upscale_factor)),
resample=Image.Resampling.LANCZOS, # type: ignore
)
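
    # Override sketch for a model-based pre-upscale; `self.esrgan` and
    # `upscale_with_tiling` are hypothetical pieces you would wire up yourself,
    # and the division by 4 assumes a x4 ESRGAN:
    #
    #     class ESRGANUpscaler(MultiUpscaler):
    #         def pre_upscale(self, image: Image.Image, upscale_factor: float, **kwargs: Any) -> Image.Image:
    #             image = self.esrgan.upscale_with_tiling(image)
    #             return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4, **kwargs)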
def compute_upscaler_targets(
self,
image: Image.Image,
latent_size: Size,
tile_size: Size,
num_inference_steps: int,
first_step: int,
condition_scale: float,
clip_text_embedding: torch.Tensor,
) -> Sequence[UpscalerTarget]:
tiles = MultiDiffusion.generate_latent_tiles(size=latent_size, tile_size=tile_size, min_overlap=8)
        targets: list[UpscalerTarget] = []
for tile in tiles:
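            # latent tile coordinates map to pixel coordinates via the 8x VAE compression factor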
pixel_box = (tile.left * 8, tile.top * 8, tile.right * 8, tile.bottom * 8)
pixel_tile = image.crop(pixel_box)
solver = self.sd.solver.rebuild(num_inference_steps=num_inference_steps, first_inference_step=first_step)
target = UpscalerTarget(
tile=tile,
solver=solver,
start_step=first_step,
condition_scale=condition_scale,
controlnet_condition=image_to_tensor(pixel_tile, device=self.device, dtype=self.dtype),
clip_text_embedding=clip_text_embedding,
)
targets.append(target)
return targets
def diffuse_targets(
self,
targets: Sequence[T],
image: Image.Image,
latent_size: Size,
first_step: int,
autoencoder_tile_length: int,
) -> Image.Image:
noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
latents = self.sd.lda.tiled_image_to_latents(image)
x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
for step in self.sd.steps:
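            # each call diffuses every target tile for this step and blends the
            # overlapping latent regions back into a single latent (MultiDiffusion)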
x = self(x, noise=noise, step=step, targets=targets)
return self.sd.lda.tiled_latents_to_image(x)
@no_grad()
def upscale(
self,
image: Image.Image,
prompt: str = "masterpiece, best quality, highres",
negative_prompt: str = "worst quality, low quality, normal quality",
upscale_factor: float = 2,
downscale_size: int = 768,
tile_size: tuple[int, int] = (144, 112),
denoise_strength: float = 0.35,
condition_scale: float = 6,
controlnet_scale: float = 0.6,
controlnet_scale_decay: float = 0.825,
loras_scale: dict[Name, float] | None = None,
solver_type: type[Solver] = DPMSolver,
num_inference_steps: int = 18,
autoencoder_tile_length: int = 1024,
seed: int = 37,
) -> Image.Image:
"""
Upscale an image using the multi upscaler.
        Default settings closely follow the original implementation: https://github.com/philz1337x/clarity-upscaler/
Args:
image: The image to upscale.
prompt: The prompt to use for the upscaling.
            negative_prompt: The negative prompt to use for the upscaling. The original default has a weight of 2.0,
                but prompt weighting is not yet supported in Refiners.
upscale_factor: The factor to upscale the image by.
            downscale_size: The size to downscale the image to, along its short side, before upscaling. Must be a
                multiple of 8 because of latent compression.
tile_size: The size (H, W) of the tiles to use for latent diffusion. The smaller the tile size, the more "fractal"
the upscaling will be.
denoise_strength: The strength of the denoising. A value of 0.0 means no denoising (so nothing happens),
while a value of 1.0 means full denoising and maximum creativity.
condition_scale: The scale of the condition. Higher values will create images with more contrast. This
parameter is called "dynamic" or "HDR" in the original implementation.
controlnet_scale: The scale of the Tile Controlnet. This parameter is called "resemblance" in the original
implementation.
controlnet_scale_decay: Applies an exponential decay to the controlnet scale on the blocks of the UNet. This
has the effect of diminishing the controlnet scale in a subtle way. The default value is 0.825
corresponding to the "Prompt is more important" parameter in the original implementation.
            loras_scale: The scale of the LoRAs. This is a dictionary where the key is the name of the LoRA and the
                value is the scale.
solver_type: The type of solver to use for the latent diffusion. The default is the DPM solver.
num_inference_steps: The number of inference steps to use for the latent diffusion. This is a trade-off
between quality and speed.
autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
lowering it can reduce GPU memory usage (but increase computation time).
seed: The seed to use for the random number generator.
"""
manual_seed(seed)
# update controlnet scale
self.controlnet.scale = controlnet_scale
self.controlnet.scale_decay = controlnet_scale_decay
# update LoRA scales
if self.manager is not None and loras_scale is not None:
self.manager.update_scales(loras_scale)
# update the solver
first_step = int(num_inference_steps * (1 - denoise_strength))
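        # e.g. with the defaults (num_inference_steps=18, denoise_strength=0.35),
        # first_step = int(18 * 0.65) = 11, so only the last 7 steps are diffused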
self.sd.solver = solver_type(
num_inference_steps=num_inference_steps,
first_inference_step=first_step,
device=self.device,
dtype=self.dtype,
)
# compute clip text embedding
clip_text_embedding = self.compute_clip_text_embedding(prompt=prompt, negative_prompt=negative_prompt)
# prepare the image for the upscale
image = self.resize_modulo_8(image, size=downscale_size)
image = self.pre_upscale(image, upscale_factor=upscale_factor)
# compute the latent size and tile size
latent_size = Size(height=image.height // 8, width=image.width // 8)
tile_size = Size(height=tile_size[0], width=tile_size[1])
# split the image into tiles
        targets: Sequence[T] = self.compute_targets(
image=image,
latent_size=latent_size,
tile_size=tile_size,
num_inference_steps=num_inference_steps,
first_step=first_step,
condition_scale=condition_scale,
clip_text_embedding=clip_text_embedding,
)
# diffuse the tiles
return self.diffuse_targets(
targets=targets,
image=image,
latent_size=latent_size,
first_step=first_step,
autoencoder_tile_length=autoencoder_tile_length,
)
class MultiUpscaler(MultiUpscalerAbstract[UpscalerTarget]):
def diffuse_target(self, x: Tensor, step: int, target: UpscalerTarget) -> Tensor:
return self.diffuse_upscaler_target(x=x, step=step, target=target)
def compute_targets(
self,
image: Image.Image,
latent_size: Size,
tile_size: Size,
num_inference_steps: int,
first_step: int,
condition_scale: float,
clip_text_embedding: torch.Tensor,
) -> Sequence[UpscalerTarget]:
return self.compute_upscaler_targets(
image=image,
latent_size=latent_size,
tile_size=tile_size,
num_inference_steps=num_inference_steps,
first_step=first_step,
condition_scale=condition_scale,
clip_text_embedding=clip_text_embedding,
)
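
# End-to-end usage sketch (assumes `checkpoints` was built as in the
# UpscalerCheckpoints example above; device, dtype, and file names are
# illustrative, not requirements):
#
#     upscaler = MultiUpscaler(
#         checkpoints=checkpoints,
#         device=torch.device("cuda"),
#         dtype=torch.float16,
#     )
#     input_image = Image.open("input.png")
#     output_image = upscaler.upscale(input_image, upscale_factor=2)
#     output_image.save("output.png")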