From 66cd0d57a1110f895640c7eb4c01c0e55783617a Mon Sep 17 00:00:00 2001
From: Laurent
Date: Thu, 11 Jul 2024 13:04:06 +0000
Subject: [PATCH] improve MultiDiffusion pipelines
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: limiteinductive
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
---
 .../latent_diffusion/multi_diffusion.py    | 163 ++++++++++++------
 .../stable_diffusion_1/multi_diffusion.py  |  44 ++---
 .../stable_diffusion_xl/multi_diffusion.py |  22 ++-
 3 files changed, 148 insertions(+), 81 deletions(-)
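Note on the core update (a worked toy example, not part of the diff): the reworked MultiDiffusion.__call__ below fuses overlapping per-tile predictions by accumulating, for each latent pixel, the total weight and the weighted sum of views, then averaging wherever at least one tile contributed:

import torch

# Two overlapping views on a 4-pixel row, both with weight 1.
x = torch.zeros(1, 1, 1, 4)
num_updates = torch.zeros_like(x)
cumulative_values = torch.zeros_like(x)
num_updates[..., 0:3] += 1.0
cumulative_values[..., 0:3] += 2.0  # view A predicts 2.0 on pixels 0..2
num_updates[..., 1:4] += 1.0
cumulative_values[..., 1:4] += 4.0  # view B predicts 4.0 on pixels 1..3
out = torch.where(num_updates > 0, cumulative_values / num_updates, x)
# out is [2.0, 3.0, 3.0, 4.0]: the overlap averages to 3.0; uncovered pixels would keep x.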
+ """ + + tile: Tile + solver: Solver init_latents: Tensor | None = None - mask_latent: Tensor | None = None + opacity_mask: Tensor | None = None weight: int = 1 - condition_scale: float = 7.5 start_step: int = 0 end_step: int = MAX_STEPS + @property + def size(self) -> Size: + return Size( + height=self.tile.bottom - self.tile.top, + width=self.tile.right - self.tile.left, + ) + + @property + def offset(self) -> tuple[int, int]: + return self.tile.top, self.tile.left + def crop(self, tensor: Tensor, /) -> Tensor: height, width = self.size top_offset, left_offset = self.offset @@ -35,15 +82,17 @@ class DiffusionTarget: return tensor -T = TypeVar("T", bound=LatentDiffusionModel) -D = TypeVar("D", bound=DiffusionTarget) +T = TypeVar("T", bound=DiffusionTarget) -@dataclass -class MultiDiffusion(Generic[T, D], ABC): - ldm: T +class MultiDiffusion(ABC, Generic[T]): + """ + MultiDiffusion class for performing multi-target diffusion using tiled diffusion. - def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) -> Tensor: + For more details, refer to the paper: [MultiDiffusion](https://arxiv.org/abs/2302.08113) + """ + + def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: Sequence[T]) -> Tensor: num_updates = torch.zeros_like(input=x) cumulative_values = torch.zeros_like(input=x) @@ -51,7 +100,7 @@ class MultiDiffusion(Generic[T, D], ABC): match step: case step if step == target.start_step and target.init_latents is not None: noise_view = target.crop(noise) - view = self.ldm.solver.add_noise( + view = target.solver.add_noise( x=target.init_latents, noise=noise_view, step=step, @@ -61,44 +110,60 @@ class MultiDiffusion(Generic[T, D], ABC): case _: continue view = self.diffuse_target(x=view, step=step, target=target) - weight = target.weight * target.mask_latent if target.mask_latent is not None else target.weight + weight = target.weight * target.opacity_mask if target.opacity_mask is not None else target.weight num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight) cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view) return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x) @abstractmethod - def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ... - - @property - def steps(self) -> list[int]: - return self.ldm.steps - - @property - def device(self) -> Device: - return self.ldm.device - - @property - def dtype(self) -> DType: - return self.ldm.dtype - - # backward-compatibility alias - def decode_latents(self, x: Tensor) -> Image.Image: - return self.latents_to_image(x=x) - - def latents_to_image(self, x: Tensor) -> Image.Image: - return self.ldm.lda.latents_to_image(x=x) - - def latents_to_images(self, x: Tensor) -> list[Image.Image]: - return self.ldm.lda.latents_to_images(x=x) + def diffuse_target(self, x: Tensor, step: int, target: T) -> Tensor: ... @staticmethod - def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]: - height, width = size + def generate_latent_tiles(size: Size, tile_size: Size, min_overlap: int = 8) -> list[Tile]: + """ + Generate tiles for a latent image with the given size and tile size. 
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py
index 0355d85..592b69c 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py
@@ -1,41 +1,31 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 
-from PIL import Image
 from torch import Tensor
 
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
     StableDiffusion_1,
-    StableDiffusion_1_Inpainting,
 )
 
 
-class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
-    def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
-        return self.ldm(
-            x=x,
-            step=step,
-            clip_text_embedding=target.clip_text_embedding,
-            condition_scale=target.condition_scale,
-        )
-
-
-@dataclass
-class InpaintingDiffusionTarget(DiffusionTarget):
-    target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
-    mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))
-
-
-class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
-    def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
-        self.ldm.set_inpainting_conditions(
-            target_image=target.target_image,
-            mask=target.mask,
-        )
-
-        return self.ldm(
+@dataclass(kw_only=True)
+class SD1DiffusionTarget(DiffusionTarget):
+    clip_text_embedding: Tensor
+    condition_scale: float = 7.0
+
+
+class SD1MultiDiffusion(MultiDiffusion[SD1DiffusionTarget]):
+    def __init__(self, sd: StableDiffusion_1) -> None:
+        self.sd = sd
+
+    def diffuse_target(self, x: Tensor, step: int, target: SD1DiffusionTarget) -> Tensor:
+        old_solver = self.sd.solver
+        self.sd.solver = target.solver
+        result = self.sd(
             x=x,
             step=step,
             clip_text_embedding=target.clip_text_embedding,
             condition_scale=target.condition_scale,
         )
+        self.sd.solver = old_solver
+        return result
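Usage sketch (not part of the diff): a minimal driver for the new API. It assumes the usual refiners entry points (`StableDiffusion_1()`, `set_inference_steps`, `compute_clip_text_embedding`, the `steps` property, and `solver.rebuild()` to hand each target its own solver instance); treat those names as assumptions rather than guarantees of this patch.

import torch

from refiners.foundationals.latent_diffusion.multi_diffusion import Size
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import StableDiffusion_1
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import (
    SD1DiffusionTarget,
    SD1MultiDiffusion,
)

sd = StableDiffusion_1()
sd.set_inference_steps(30)  # assumed API for configuring the schedule
md = SD1MultiDiffusion(sd=sd)

clip_text_embedding = sd.compute_clip_text_embedding(text="a wide mountain panorama")

# One target per tile; each target carries its own solver so stateful solvers stay isolated.
targets = [
    SD1DiffusionTarget(
        tile=tile,
        solver=sd.solver.rebuild(),  # assumed to return a fresh solver copy
        clip_text_embedding=clip_text_embedding,
    )
    for tile in md.generate_latent_tiles(
        size=Size(height=64, width=192),  # latent for a 512x1536 panorama
        tile_size=Size(height=64, width=64),
    )
]

x = torch.randn(1, 4, 64, 192)
noise = torch.randn_like(x)  # only consumed by targets with init_latents
for step in sd.steps:
    x = md(x, noise=noise, step=step, targets=targets)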
diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/multi_diffusion.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/multi_diffusion.py
index 2d905e0..29e7b7c 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/multi_diffusion.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/multi_diffusion.py
@@ -1,17 +1,27 @@
+from dataclasses import dataclass
+
 from torch import Tensor
 
+from refiners.foundationals.latent_diffusion import StableDiffusion_XL
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
-from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
 
 
-class SDXLDiffusionTarget(DiffusionTarget):
+@dataclass(kw_only=True)
+class SDXLTarget(DiffusionTarget):
+    clip_text_embedding: Tensor
+    condition_scale: float = 5.0
     pooled_text_embedding: Tensor
     time_ids: Tensor
 
 
-class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
-    def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
-        return self.ldm(
+class SDXLMultiDiffusion(MultiDiffusion[SDXLTarget]):
+    def __init__(self, sd: StableDiffusion_XL) -> None:
+        self.sd = sd
+
+    def diffuse_target(self, x: Tensor, step: int, target: SDXLTarget) -> Tensor:
+        old_solver = self.sd.solver
+        self.sd.solver = target.solver
+        result = self.sd(
             x=x,
             step=step,
             clip_text_embedding=target.clip_text_embedding,
@@ -19,3 +29,5 @@ class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]
             time_ids=target.time_ids,
             condition_scale=target.condition_scale,
         )
+        self.sd.solver = old_solver
+        return result
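The SDXL flavor follows the same pattern with two extra conditioning inputs. A sketch under the assumption that `StableDiffusion_XL.compute_clip_text_embedding` returns both the text and pooled embeddings and that a `default_time_ids` property exists (both are assumptions about the surrounding API, not part of this patch):

import torch

from refiners.foundationals.latent_diffusion import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.multi_diffusion import Size
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.multi_diffusion import (
    SDXLMultiDiffusion,
    SDXLTarget,
)

sdxl = StableDiffusion_XL()
sdxl.set_inference_steps(30)  # assumed API
md = SDXLMultiDiffusion(sd=sdxl)

clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
    text="a wide mountain panorama", negative_text=""
)  # assumed to return (text embedding, pooled embedding)

targets = [
    SDXLTarget(
        tile=tile,
        solver=sdxl.solver.rebuild(),  # assumed fresh solver per target
        clip_text_embedding=clip_text_embedding,
        pooled_text_embedding=pooled_text_embedding,
        time_ids=sdxl.default_time_ids,  # assumed property
    )
    for tile in md.generate_latent_tiles(
        size=Size(height=128, width=256),  # latent for a 1024x2048 panorama
        tile_size=Size(height=128, width=128),
    )
]

x = torch.randn(1, 4, 128, 256)
noise = torch.randn_like(x)
for step in sdxl.steps:
    x = md(x, noise=noise, step=step, targets=targets)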