improve MultiDiffusion pipelines

Co-authored-by: limiteinductive <benjamin@lagon.tech>
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
Laurent 2024-07-11 13:04:06 +00:00 committed by Laureηt
parent b4db08de24
commit 66cd0d57a1
3 changed files with 148 additions and 81 deletions

View file

@ -1,28 +1,75 @@
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar
from typing import Generic, NamedTuple, Sequence
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
from torch import Tensor
from typing_extensions import TypeVar
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
MAX_STEPS = 1000
@dataclass
class Tile(NamedTuple):
top: int
left: int
bottom: int
right: int
class Size(NamedTuple):
height: int
width: int
@dataclass(kw_only=True)
class DiffusionTarget:
size: tuple[int, int]
offset: tuple[int, int]
clip_text_embedding: Tensor
"""
Represents a target for the tiled diffusion process.
This class encapsulates the parameters and properties needed to define a specific area (target) within a larger
diffusion process, allowing for fine-grained control over different regions of the generated image.
Attributes:
tile: The tile defining the area of the target within the latent image.
solver: The solver to use for this target's diffusion process. This is useful because some solvers have an
internal state that needs to be updated during the diffusion process. Using the same solver instance for
multiple targets would interfere with this internal state.
init_latents: The initial latents for this target. If None, the target will be initialized with noise.
opacity_mask: Mask controlling the target's visibility in the final image.
If None, the target will be fully visible. Otherwise, 1 means fully opaque and 0 means fully transparent
which means the target has no influence.
weight: The importance of this target in the final image. Higher values increase the target's influence.
start_step: The diffusion step at which this target begins to influence the process.
end_step: The diffusion step at which this target stops influencing the process.
size: The size of the target area.
offset: The top-left offset of the target area within the latent image.
The combination of `opacity_mask` and `weight` determines the target's overall contribution to the final generated
image. The `solver` is responsible for the actual diffusion calculations for this target.
"""
tile: Tile
solver: Solver
init_latents: Tensor | None = None
mask_latent: Tensor | None = None
opacity_mask: Tensor | None = None
weight: int = 1
condition_scale: float = 7.5
start_step: int = 0
end_step: int = MAX_STEPS
@property
def size(self) -> Size:
return Size(
height=self.tile.bottom - self.tile.top,
width=self.tile.right - self.tile.left,
)
@property
def offset(self) -> tuple[int, int]:
return self.tile.top, self.tile.left
def crop(self, tensor: Tensor, /) -> Tensor:
height, width = self.size
top_offset, left_offset = self.offset
@ -35,15 +82,17 @@ class DiffusionTarget:
return tensor
T = TypeVar("T", bound=LatentDiffusionModel)
D = TypeVar("D", bound=DiffusionTarget)
T = TypeVar("T", bound=DiffusionTarget)
@dataclass
class MultiDiffusion(Generic[T, D], ABC):
ldm: T
class MultiDiffusion(ABC, Generic[T]):
"""
MultiDiffusion class for performing multi-target diffusion using tiled diffusion.
def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) -> Tensor:
For more details, refer to the paper: [MultiDiffusion](https://arxiv.org/abs/2302.08113)
"""
def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: Sequence[T]) -> Tensor:
num_updates = torch.zeros_like(input=x)
cumulative_values = torch.zeros_like(input=x)
@ -51,7 +100,7 @@ class MultiDiffusion(Generic[T, D], ABC):
match step:
case step if step == target.start_step and target.init_latents is not None:
noise_view = target.crop(noise)
view = self.ldm.solver.add_noise(
view = target.solver.add_noise(
x=target.init_latents,
noise=noise_view,
step=step,
@ -61,44 +110,60 @@ class MultiDiffusion(Generic[T, D], ABC):
case _:
continue
view = self.diffuse_target(x=view, step=step, target=target)
weight = target.weight * target.mask_latent if target.mask_latent is not None else target.weight
weight = target.weight * target.opacity_mask if target.opacity_mask is not None else target.weight
num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight)
cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view)
return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)
@abstractmethod
def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ...
@property
def steps(self) -> list[int]:
return self.ldm.steps
@property
def device(self) -> Device:
return self.ldm.device
@property
def dtype(self) -> DType:
return self.ldm.dtype
# backward-compatibility alias
def decode_latents(self, x: Tensor) -> Image.Image:
return self.latents_to_image(x=x)
def latents_to_image(self, x: Tensor) -> Image.Image:
return self.ldm.lda.latents_to_image(x=x)
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
return self.ldm.lda.latents_to_images(x=x)
def diffuse_target(self, x: Tensor, step: int, target: T) -> Tensor: ...
@staticmethod
def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
height, width = size
def generate_latent_tiles(size: Size, tile_size: Size, min_overlap: int = 8) -> list[Tile]:
"""
Generate tiles for a latent image with the given size and tile size.
return [
(y, x)
for y in range(0, height, stride)
for x in range(0, width, stride)
if y + 64 <= height and x + 64 <= width
]
If one dimension of the `tile_size` is larger than the corresponding dimension of the image size, a single tile is
used to cover the entire image - and therefore `tile_size` is ignored. This algorithm ensures that the tile size
is respected as much as possible, while still covering the entire image and respecting the minimum overlap.
"""
assert (
0 <= min_overlap < min(tile_size.height, tile_size.width)
), "Overlap must be non-negative and less than the tile size"
if tile_size.width > size.width or tile_size.height > size.height:
return [Tile(top=0, left=0, bottom=size.height, right=size.width)]
tiles: list[Tile] = []
def _compute_tiles_and_overlap(length: int, tile_length: int, min_overlap: int) -> tuple[int, int]:
if tile_length >= length:
return 1, 0
num_tiles = math.ceil((length - tile_length) / (tile_length - min_overlap)) + 1
overlap = (num_tiles * tile_length - length) // (num_tiles - 1)
return num_tiles, overlap
num_tiles_x, overlap_x = _compute_tiles_and_overlap(
length=size.width, tile_length=tile_size.width, min_overlap=min_overlap
)
num_tiles_y, overlap_y = _compute_tiles_and_overlap(
length=size.height, tile_length=tile_size.height, min_overlap=min_overlap
)
for i in range(num_tiles_y):
for j in range(num_tiles_x):
x = j * (tile_size.width - overlap_x)
y = i * (tile_size.height - overlap_y)
# Adjust x and y coordinates to ensure full-sized tiles
if x + tile_size.width > size.width:
x = size.width - tile_size.width
if y + tile_size.height > size.height:
y = size.height - tile_size.height
tile_right = x + tile_size.width
tile_bottom = y + tile_size.height
tiles.append(Tile(top=y, left=x, bottom=tile_bottom, right=tile_right))
return tiles

View file

@ -1,41 +1,31 @@
from dataclasses import dataclass, field
from dataclasses import dataclass
from PIL import Image
from torch import Tensor
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
return self.ldm(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
condition_scale=target.condition_scale,
)
@dataclass
class InpaintingDiffusionTarget(DiffusionTarget):
target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))
class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
self.ldm.set_inpainting_conditions(
target_image=target.target_image,
mask=target.mask,
)
return self.ldm(
@dataclass(kw_only=True)
class SD1DiffusionTarget(DiffusionTarget):
clip_text_embedding: Tensor
condition_scale: float = 7.0
class SD1MultiDiffusion(MultiDiffusion[SD1DiffusionTarget]):
def __init__(self, sd: StableDiffusion_1) -> None:
self.sd = sd
def diffuse_target(self, x: Tensor, step: int, target: SD1DiffusionTarget) -> Tensor:
old_solver = self.sd.solver
self.sd.solver = target.solver
result = self.sd(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
condition_scale=target.condition_scale,
)
self.sd.solver = old_solver
return result

View file

@ -1,17 +1,27 @@
from dataclasses import dataclass
from torch import Tensor
from refiners.foundationals.latent_diffusion import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
class SDXLDiffusionTarget(DiffusionTarget):
@dataclass(kw_only=True)
class SDXLTarget(DiffusionTarget):
clip_text_embedding: Tensor
condition_scale: float = 5.0
pooled_text_embedding: Tensor
time_ids: Tensor
class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
return self.ldm(
class SDXLMultiDiffusion(MultiDiffusion[SDXLTarget]):
def __init__(self, sd: StableDiffusion_XL) -> None:
self.sd = sd
def diffuse_target(self, x: Tensor, step: int, target: SDXLTarget) -> Tensor:
old_solver = self.sd.solver
self.sd.solver = target.solver
result = self.sd(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
@ -19,3 +29,5 @@ class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]
time_ids=target.time_ids,
condition_scale=target.condition_scale,
)
self.sd.solver = old_solver
return result