improve MultiDiffusion pipelines

Co-authored-by: limiteinductive <benjamin@lagon.tech>
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
Laurent 2024-07-11 13:04:06 +00:00 committed by Laureηt
parent b4db08de24
commit 66cd0d57a1
3 changed files with 148 additions and 81 deletions

src/refiners/foundationals/latent_diffusion/multi_diffusion.py

@@ -1,28 +1,75 @@
+import math
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Generic, TypeVar
+from typing import Generic, NamedTuple, Sequence

 import torch
-from PIL import Image
-from torch import Tensor, device as Device, dtype as DType
+from torch import Tensor
+from typing_extensions import TypeVar

-from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
+from refiners.foundationals.latent_diffusion.solvers.solver import Solver

 MAX_STEPS = 1000


-@dataclass
+class Tile(NamedTuple):
+    top: int
+    left: int
+    bottom: int
+    right: int
+
+
+class Size(NamedTuple):
+    height: int
+    width: int
+
+
+@dataclass(kw_only=True)
 class DiffusionTarget:
-    size: tuple[int, int]
-    offset: tuple[int, int]
-    clip_text_embedding: Tensor
+    """
+    Represents a target for the tiled diffusion process.
+
+    This class encapsulates the parameters and properties needed to define a specific area (target) within a larger
+    diffusion process, allowing for fine-grained control over different regions of the generated image.
+
+    Attributes:
+        tile: The tile defining the area of the target within the latent image.
+        solver: The solver to use for this target's diffusion process. This is useful because some solvers have an
+            internal state that needs to be updated during the diffusion process. Using the same solver instance for
+            multiple targets would interfere with this internal state.
+        init_latents: The initial latents for this target. If None, the target will be initialized with noise.
+        opacity_mask: Mask controlling the target's visibility in the final image. If None, the target is fully
+            visible. Otherwise, 1 means fully opaque and 0 means fully transparent, i.e. the target has no influence.
+        weight: The importance of this target in the final image. Higher values increase the target's influence.
+        start_step: The diffusion step at which this target begins to influence the process.
+        end_step: The diffusion step at which this target stops influencing the process.
+        size: The size of the target area.
+        offset: The top-left offset of the target area within the latent image.
+
+    The combination of `opacity_mask` and `weight` determines the target's overall contribution to the final
+    generated image. The `solver` is responsible for the actual diffusion computations for this target.
+    """
+
+    tile: Tile
+    solver: Solver
     init_latents: Tensor | None = None
-    mask_latent: Tensor | None = None
+    opacity_mask: Tensor | None = None
     weight: int = 1
-    condition_scale: float = 7.5
     start_step: int = 0
     end_step: int = MAX_STEPS

+    @property
+    def size(self) -> Size:
+        return Size(
+            height=self.tile.bottom - self.tile.top,
+            width=self.tile.right - self.tile.left,
+        )
+
+    @property
+    def offset(self) -> tuple[int, int]:
+        return self.tile.top, self.tile.left
+
     def crop(self, tensor: Tensor, /) -> Tensor:
         height, width = self.size
         top_offset, left_offset = self.offset
@@ -35,15 +82,17 @@ class DiffusionTarget:
         return tensor


-T = TypeVar("T", bound=LatentDiffusionModel)
-D = TypeVar("D", bound=DiffusionTarget)
+T = TypeVar("T", bound=DiffusionTarget)


-@dataclass
-class MultiDiffusion(Generic[T, D], ABC):
-    ldm: T
+class MultiDiffusion(ABC, Generic[T]):
+    """
+    MultiDiffusion class for performing multi-target diffusion using tiled diffusion.
+
+    For more details, refer to the paper: [MultiDiffusion](https://arxiv.org/abs/2302.08113)
+    """

-    def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) -> Tensor:
+    def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: Sequence[T]) -> Tensor:
         num_updates = torch.zeros_like(input=x)
         cumulative_values = torch.zeros_like(input=x)
@@ -51,7 +100,7 @@ class MultiDiffusion(Generic[T, D], ABC):
             match step:
                 case step if step == target.start_step and target.init_latents is not None:
                     noise_view = target.crop(noise)
-                    view = self.ldm.solver.add_noise(
+                    view = target.solver.add_noise(
                         x=target.init_latents,
                         noise=noise_view,
                         step=step,
@@ -61,44 +110,60 @@ class MultiDiffusion(Generic[T, D], ABC):
                 case _:
                     continue

             view = self.diffuse_target(x=view, step=step, target=target)
-            weight = target.weight * target.mask_latent if target.mask_latent is not None else target.weight
+            weight = target.weight * target.opacity_mask if target.opacity_mask is not None else target.weight
             num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight)
             cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view)

         return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)

     @abstractmethod
-    def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ...
+    def diffuse_target(self, x: Tensor, step: int, target: T) -> Tensor: ...

-    @property
-    def steps(self) -> list[int]:
-        return self.ldm.steps
-
-    @property
-    def device(self) -> Device:
-        return self.ldm.device
-
-    @property
-    def dtype(self) -> DType:
-        return self.ldm.dtype
-
-    # backward-compatibility alias
-    def decode_latents(self, x: Tensor) -> Image.Image:
-        return self.latents_to_image(x=x)
-
-    def latents_to_image(self, x: Tensor) -> Image.Image:
-        return self.ldm.lda.latents_to_image(x=x)
-
-    def latents_to_images(self, x: Tensor) -> list[Image.Image]:
-        return self.ldm.lda.latents_to_images(x=x)
-
     @staticmethod
-    def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
-        height, width = size
-
-        return [
-            (y, x)
-            for y in range(0, height, stride)
-            for x in range(0, width, stride)
-            if y + 64 <= height and x + 64 <= width
-        ]
+    def generate_latent_tiles(size: Size, tile_size: Size, min_overlap: int = 8) -> list[Tile]:
+        """
+        Generate tiles for a latent image with the given size and tile size.
+
+        If one dimension of the `tile_size` is larger than the corresponding dimension of the image size, a single
+        tile is used to cover the entire image, and therefore `tile_size` is ignored. This algorithm ensures that
+        the tile size is respected as much as possible, while still covering the entire image and respecting the
+        minimum overlap.
+        """
+        assert (
+            0 <= min_overlap < min(tile_size.height, tile_size.width)
+        ), "Overlap must be non-negative and less than the tile size"
+
+        if tile_size.width > size.width or tile_size.height > size.height:
+            return [Tile(top=0, left=0, bottom=size.height, right=size.width)]
+
+        tiles: list[Tile] = []
+
+        def _compute_tiles_and_overlap(length: int, tile_length: int, min_overlap: int) -> tuple[int, int]:
+            if tile_length >= length:
+                return 1, 0
+            num_tiles = math.ceil((length - tile_length) / (tile_length - min_overlap)) + 1
+            overlap = (num_tiles * tile_length - length) // (num_tiles - 1)
+            return num_tiles, overlap
+
+        num_tiles_x, overlap_x = _compute_tiles_and_overlap(
+            length=size.width, tile_length=tile_size.width, min_overlap=min_overlap
+        )
+        num_tiles_y, overlap_y = _compute_tiles_and_overlap(
+            length=size.height, tile_length=tile_size.height, min_overlap=min_overlap
+        )
+
+        for i in range(num_tiles_y):
+            for j in range(num_tiles_x):
+                x = j * (tile_size.width - overlap_x)
+                y = i * (tile_size.height - overlap_y)
+
+                # Adjust x and y coordinates to ensure full-sized tiles
+                if x + tile_size.width > size.width:
+                    x = size.width - tile_size.width
+                if y + tile_size.height > size.height:
+                    y = size.height - tile_size.height
+
+                tile_right = x + tile_size.width
+                tile_bottom = y + tile_size.height
+                tiles.append(Tile(top=y, left=x, bottom=tile_bottom, right=tile_right))
+
+        return tiles
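
As a quick sanity check of the tiling algorithm above, here is a minimal usage sketch using only the names defined in this file; the concrete numbers are just an illustration. Splitting a 96x96 latent into 64x64 tiles with a minimum overlap of 8 yields two tiles per axis, and the overlap is stretched to 32 so the tiles cover the image exactly:

from refiners.foundationals.latent_diffusion.multi_diffusion import MultiDiffusion, Size

# 96x96 latent, 64x64 tiles, at least 8 latent pixels of overlap.
tiles = MultiDiffusion.generate_latent_tiles(
    size=Size(height=96, width=96),
    tile_size=Size(height=64, width=64),
    min_overlap=8,
)
# num_tiles = ceil((96 - 64) / (64 - 8)) + 1 = 2 per axis,
# overlap = (2 * 64 - 96) // (2 - 1) = 32, hence:
# [Tile(top=0, left=0, bottom=64, right=64),
#  Tile(top=0, left=32, bottom=64, right=96),
#  Tile(top=32, left=0, bottom=96, right=64),
#  Tile(top=32, left=32, bottom=96, right=96)]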

src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_diffusion.py

@@ -1,41 +1,31 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass

-from PIL import Image
 from torch import Tensor

 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
     StableDiffusion_1,
-    StableDiffusion_1_Inpainting,
 )


-class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
-    def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
-        return self.ldm(
-            x=x,
-            step=step,
-            clip_text_embedding=target.clip_text_embedding,
-            condition_scale=target.condition_scale,
-        )
+@dataclass(kw_only=True)
+class SD1DiffusionTarget(DiffusionTarget):
+    clip_text_embedding: Tensor
+    condition_scale: float = 7.0


-@dataclass
-class InpaintingDiffusionTarget(DiffusionTarget):
-    target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
-    mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))
-
-
-class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
-    def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
-        self.ldm.set_inpainting_conditions(
-            target_image=target.target_image,
-            mask=target.mask,
-        )
-        return self.ldm(
+class SD1MultiDiffusion(MultiDiffusion[SD1DiffusionTarget]):
+    def __init__(self, sd: StableDiffusion_1) -> None:
+        self.sd = sd
+
+    def diffuse_target(self, x: Tensor, step: int, target: SD1DiffusionTarget) -> Tensor:
+        old_solver = self.sd.solver
+        self.sd.solver = target.solver
+        result = self.sd(
             x=x,
             step=step,
             clip_text_embedding=target.clip_text_embedding,
             condition_scale=target.condition_scale,
         )
+        self.sd.solver = old_solver
+        return result
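
To see how the pieces fit together end to end, here is a minimal sketch of a 512x1280 panorama. It assumes a weight-loaded StableDiffusion_1 instance named `sd` and the usual refiners entry points (`compute_clip_text_embedding`, `steps`, `lda.latents_to_image`); `solver.rebuild()` is assumed here as a convenient way to obtain the fresh solver copy per target that the DiffusionTarget docstring requires:

import torch

from refiners.foundationals.latent_diffusion.multi_diffusion import MultiDiffusion, Size
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import (
    SD1DiffusionTarget,
    SD1MultiDiffusion,
)

md = SD1MultiDiffusion(sd)  # `sd` is a weight-loaded StableDiffusion_1
clip_text_embedding = sd.compute_clip_text_embedding(text="a wide mountain panorama")

# One target per tile, each with its own solver copy so that solvers with
# internal state do not interfere with one another.
targets = [
    SD1DiffusionTarget(
        tile=tile,
        solver=sd.solver.rebuild(),  # any fresh solver instance works
        clip_text_embedding=clip_text_embedding,
    )
    for tile in MultiDiffusion.generate_latent_tiles(
        size=Size(height=64, width=160),  # 512x1280 image is 64x160 in latent space
        tile_size=Size(height=64, width=64),
    )
]

noise = torch.randn(1, 4, 64, 160, device=sd.device, dtype=sd.dtype)
x = noise
for step in sd.steps:
    x = md(x, noise=noise, step=step, targets=targets)
image = sd.lda.latents_to_image(x)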

src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/multi_diffusion.py

@@ -1,17 +1,27 @@
+from dataclasses import dataclass
+
 from torch import Tensor

-from refiners.foundationals.latent_diffusion import StableDiffusion_XL
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
+from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL


-class SDXLDiffusionTarget(DiffusionTarget):
+@dataclass(kw_only=True)
+class SDXLTarget(DiffusionTarget):
+    clip_text_embedding: Tensor
+    condition_scale: float = 5.0
     pooled_text_embedding: Tensor
     time_ids: Tensor


-class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
-    def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
-        return self.ldm(
+class SDXLMultiDiffusion(MultiDiffusion[SDXLTarget]):
+    def __init__(self, sd: StableDiffusion_XL) -> None:
+        self.sd = sd
+
+    def diffuse_target(self, x: Tensor, step: int, target: SDXLTarget) -> Tensor:
+        old_solver = self.sd.solver
+        self.sd.solver = target.solver
+        result = self.sd(
             x=x,
             step=step,
             clip_text_embedding=target.clip_text_embedding,
@@ -19,3 +29,5 @@ class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]
             time_ids=target.time_ids,
             condition_scale=target.condition_scale,
         )
+        self.sd.solver = old_solver
+        return result
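
The SDXL variant follows the same solver-swapping pattern; only the target carries extra conditioning. A short sketch of target construction, assuming StableDiffusion_XL exposes `compute_clip_text_embedding` (returning both the text and pooled embeddings) and `default_time_ids`, as elsewhere in refiners:

from refiners.foundationals.latent_diffusion.stable_diffusion_xl.multi_diffusion import (
    SDXLMultiDiffusion,
    SDXLTarget,
)

md = SDXLMultiDiffusion(sdxl)  # `sdxl` is a weight-loaded StableDiffusion_XL
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
    text="a panoramic beach at sunset",
)
targets = [
    SDXLTarget(
        tile=tile,
        solver=sdxl.solver.rebuild(),  # fresh solver copy per target, as above
        clip_text_embedding=clip_text_embedding,
        pooled_text_embedding=pooled_text_embedding,
        time_ids=sdxl.default_time_ids,
    )
    for tile in tiles  # tiles from generate_latent_tiles, as in the SD1 example
]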