mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
improve MultiDiffusion pipelines
Co-authored-by: limiteinductive <benjamin@lagon.tech> Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
parent
b4db08de24
commit
66cd0d57a1
|
@ -1,28 +1,75 @@
|
||||||
|
import math
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Generic, TypeVar
|
from typing import Generic, NamedTuple, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from torch import Tensor
|
||||||
from torch import Tensor, device as Device, dtype as DType
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
from refiners.foundationals.latent_diffusion.solvers.solver import Solver
|
||||||
|
|
||||||
MAX_STEPS = 1000
|
MAX_STEPS = 1000
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class Tile(NamedTuple):
|
||||||
|
top: int
|
||||||
|
left: int
|
||||||
|
bottom: int
|
||||||
|
right: int
|
||||||
|
|
||||||
|
|
||||||
|
class Size(NamedTuple):
|
||||||
|
height: int
|
||||||
|
width: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(kw_only=True)
|
||||||
class DiffusionTarget:
|
class DiffusionTarget:
|
||||||
size: tuple[int, int]
|
"""
|
||||||
offset: tuple[int, int]
|
Represents a target for the tiled diffusion process.
|
||||||
clip_text_embedding: Tensor
|
|
||||||
|
This class encapsulates the parameters and properties needed to define a specific area (target) within a larger
|
||||||
|
diffusion process, allowing for fine-grained control over different regions of the generated image.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
tile: The tile defining the area of the target within the latent image.
|
||||||
|
solver: The solver to use for this target's diffusion process. This is useful because some solvers have an
|
||||||
|
internal state that needs to be updated during the diffusion process. Using the same solver instance for
|
||||||
|
multiple targets would interfere with this internal state.
|
||||||
|
init_latents: The initial latents for this target. If None, the target will be initialized with noise.
|
||||||
|
opacity_mask: Mask controlling the target's visibility in the final image.
|
||||||
|
If None, the target will be fully visible. Otherwise, 1 means fully opaque and 0 means fully transparent
|
||||||
|
which means the target has no influence.
|
||||||
|
weight: The importance of this target in the final image. Higher values increase the target's influence.
|
||||||
|
start_step: The diffusion step at which this target begins to influence the process.
|
||||||
|
end_step: The diffusion step at which this target stops influencing the process.
|
||||||
|
size: The size of the target area.
|
||||||
|
offset: The top-left offset of the target area within the latent image.
|
||||||
|
|
||||||
|
The combination of `opacity_mask` and `weight` determines the target's overall contribution to the final generated
|
||||||
|
image. The `solver` is responsible for the actual diffusion calculations for this target.
|
||||||
|
"""
|
||||||
|
|
||||||
|
tile: Tile
|
||||||
|
solver: Solver
|
||||||
init_latents: Tensor | None = None
|
init_latents: Tensor | None = None
|
||||||
mask_latent: Tensor | None = None
|
opacity_mask: Tensor | None = None
|
||||||
weight: int = 1
|
weight: int = 1
|
||||||
condition_scale: float = 7.5
|
|
||||||
start_step: int = 0
|
start_step: int = 0
|
||||||
end_step: int = MAX_STEPS
|
end_step: int = MAX_STEPS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> Size:
|
||||||
|
return Size(
|
||||||
|
height=self.tile.bottom - self.tile.top,
|
||||||
|
width=self.tile.right - self.tile.left,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def offset(self) -> tuple[int, int]:
|
||||||
|
return self.tile.top, self.tile.left
|
||||||
|
|
||||||
def crop(self, tensor: Tensor, /) -> Tensor:
|
def crop(self, tensor: Tensor, /) -> Tensor:
|
||||||
height, width = self.size
|
height, width = self.size
|
||||||
top_offset, left_offset = self.offset
|
top_offset, left_offset = self.offset
|
||||||
|
@ -35,15 +82,17 @@ class DiffusionTarget:
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=LatentDiffusionModel)
|
T = TypeVar("T", bound=DiffusionTarget)
|
||||||
D = TypeVar("D", bound=DiffusionTarget)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class MultiDiffusion(ABC, Generic[T]):
|
||||||
class MultiDiffusion(Generic[T, D], ABC):
|
"""
|
||||||
ldm: T
|
MultiDiffusion class for performing multi-target diffusion using tiled diffusion.
|
||||||
|
|
||||||
def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) -> Tensor:
|
For more details, refer to the paper: [MultiDiffusion](https://arxiv.org/abs/2302.08113)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: Sequence[T]) -> Tensor:
|
||||||
num_updates = torch.zeros_like(input=x)
|
num_updates = torch.zeros_like(input=x)
|
||||||
cumulative_values = torch.zeros_like(input=x)
|
cumulative_values = torch.zeros_like(input=x)
|
||||||
|
|
||||||
|
@ -51,7 +100,7 @@ class MultiDiffusion(Generic[T, D], ABC):
|
||||||
match step:
|
match step:
|
||||||
case step if step == target.start_step and target.init_latents is not None:
|
case step if step == target.start_step and target.init_latents is not None:
|
||||||
noise_view = target.crop(noise)
|
noise_view = target.crop(noise)
|
||||||
view = self.ldm.solver.add_noise(
|
view = target.solver.add_noise(
|
||||||
x=target.init_latents,
|
x=target.init_latents,
|
||||||
noise=noise_view,
|
noise=noise_view,
|
||||||
step=step,
|
step=step,
|
||||||
|
@ -61,44 +110,60 @@ class MultiDiffusion(Generic[T, D], ABC):
|
||||||
case _:
|
case _:
|
||||||
continue
|
continue
|
||||||
view = self.diffuse_target(x=view, step=step, target=target)
|
view = self.diffuse_target(x=view, step=step, target=target)
|
||||||
weight = target.weight * target.mask_latent if target.mask_latent is not None else target.weight
|
weight = target.weight * target.opacity_mask if target.opacity_mask is not None else target.weight
|
||||||
num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight)
|
num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight)
|
||||||
cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view)
|
cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view)
|
||||||
|
|
||||||
return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)
|
return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ...
|
def diffuse_target(self, x: Tensor, step: int, target: T) -> Tensor: ...
|
||||||
|
|
||||||
@property
|
|
||||||
def steps(self) -> list[int]:
|
|
||||||
return self.ldm.steps
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> Device:
|
|
||||||
return self.ldm.device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self) -> DType:
|
|
||||||
return self.ldm.dtype
|
|
||||||
|
|
||||||
# backward-compatibility alias
|
|
||||||
def decode_latents(self, x: Tensor) -> Image.Image:
|
|
||||||
return self.latents_to_image(x=x)
|
|
||||||
|
|
||||||
def latents_to_image(self, x: Tensor) -> Image.Image:
|
|
||||||
return self.ldm.lda.latents_to_image(x=x)
|
|
||||||
|
|
||||||
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
|
|
||||||
return self.ldm.lda.latents_to_images(x=x)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
|
def generate_latent_tiles(size: Size, tile_size: Size, min_overlap: int = 8) -> list[Tile]:
|
||||||
height, width = size
|
"""
|
||||||
|
Generate tiles for a latent image with the given size and tile size.
|
||||||
|
|
||||||
return [
|
If one dimension of the `tile_size` is larger than the corresponding dimension of the image size, a single tile is
|
||||||
(y, x)
|
used to cover the entire image - and therefore `tile_size` is ignored. This algorithm ensures that the tile size
|
||||||
for y in range(0, height, stride)
|
is respected as much as possible, while still covering the entire image and respecting the minimum overlap.
|
||||||
for x in range(0, width, stride)
|
"""
|
||||||
if y + 64 <= height and x + 64 <= width
|
assert (
|
||||||
]
|
0 <= min_overlap < min(tile_size.height, tile_size.width)
|
||||||
|
), "Overlap must be non-negative and less than the tile size"
|
||||||
|
|
||||||
|
if tile_size.width > size.width or tile_size.height > size.height:
|
||||||
|
return [Tile(top=0, left=0, bottom=size.height, right=size.width)]
|
||||||
|
|
||||||
|
tiles: list[Tile] = []
|
||||||
|
|
||||||
|
def _compute_tiles_and_overlap(length: int, tile_length: int, min_overlap: int) -> tuple[int, int]:
|
||||||
|
if tile_length >= length:
|
||||||
|
return 1, 0
|
||||||
|
num_tiles = math.ceil((length - tile_length) / (tile_length - min_overlap)) + 1
|
||||||
|
overlap = (num_tiles * tile_length - length) // (num_tiles - 1)
|
||||||
|
return num_tiles, overlap
|
||||||
|
|
||||||
|
num_tiles_x, overlap_x = _compute_tiles_and_overlap(
|
||||||
|
length=size.width, tile_length=tile_size.width, min_overlap=min_overlap
|
||||||
|
)
|
||||||
|
num_tiles_y, overlap_y = _compute_tiles_and_overlap(
|
||||||
|
length=size.height, tile_length=tile_size.height, min_overlap=min_overlap
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(num_tiles_y):
|
||||||
|
for j in range(num_tiles_x):
|
||||||
|
x = j * (tile_size.width - overlap_x)
|
||||||
|
y = i * (tile_size.height - overlap_y)
|
||||||
|
|
||||||
|
# Adjust x and y coordinates to ensure full-sized tiles
|
||||||
|
if x + tile_size.width > size.width:
|
||||||
|
x = size.width - tile_size.width
|
||||||
|
if y + tile_size.height > size.height:
|
||||||
|
y = size.height - tile_size.height
|
||||||
|
|
||||||
|
tile_right = x + tile_size.width
|
||||||
|
tile_bottom = y + tile_size.height
|
||||||
|
tiles.append(Tile(top=y, left=x, bottom=tile_bottom, right=tile_right))
|
||||||
|
|
||||||
|
return tiles
|
||||||
|
|
|
@ -1,41 +1,31 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
|
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||||
StableDiffusion_1,
|
StableDiffusion_1,
|
||||||
StableDiffusion_1_Inpainting,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
|
@dataclass(kw_only=True)
|
||||||
def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
|
class SD1DiffusionTarget(DiffusionTarget):
|
||||||
return self.ldm(
|
clip_text_embedding: Tensor
|
||||||
x=x,
|
condition_scale: float = 7.0
|
||||||
step=step,
|
|
||||||
clip_text_embedding=target.clip_text_embedding,
|
|
||||||
condition_scale=target.condition_scale,
|
class SD1MultiDiffusion(MultiDiffusion[SD1DiffusionTarget]):
|
||||||
)
|
def __init__(self, sd: StableDiffusion_1) -> None:
|
||||||
|
self.sd = sd
|
||||||
|
|
||||||
@dataclass
|
def diffuse_target(self, x: Tensor, step: int, target: SD1DiffusionTarget) -> Tensor:
|
||||||
class InpaintingDiffusionTarget(DiffusionTarget):
|
old_solver = self.sd.solver
|
||||||
target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
|
self.sd.solver = target.solver
|
||||||
mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))
|
result = self.sd(
|
||||||
|
|
||||||
|
|
||||||
class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
|
|
||||||
def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
|
|
||||||
self.ldm.set_inpainting_conditions(
|
|
||||||
target_image=target.target_image,
|
|
||||||
mask=target.mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.ldm(
|
|
||||||
x=x,
|
x=x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=target.clip_text_embedding,
|
clip_text_embedding=target.clip_text_embedding,
|
||||||
condition_scale=target.condition_scale,
|
condition_scale=target.condition_scale,
|
||||||
)
|
)
|
||||||
|
self.sd.solver = old_solver
|
||||||
|
return result
|
||||||
|
|
|
@ -1,17 +1,27 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion import StableDiffusion_XL
|
||||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
|
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
|
||||||
|
|
||||||
|
|
||||||
class SDXLDiffusionTarget(DiffusionTarget):
|
@dataclass(kw_only=True)
|
||||||
|
class SDXLTarget(DiffusionTarget):
|
||||||
|
clip_text_embedding: Tensor
|
||||||
|
condition_scale: float = 5.0
|
||||||
pooled_text_embedding: Tensor
|
pooled_text_embedding: Tensor
|
||||||
time_ids: Tensor
|
time_ids: Tensor
|
||||||
|
|
||||||
|
|
||||||
class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
|
class SDXLMultiDiffusion(MultiDiffusion[SDXLTarget]):
|
||||||
def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
|
def __init__(self, sd: StableDiffusion_XL) -> None:
|
||||||
return self.ldm(
|
self.sd = sd
|
||||||
|
|
||||||
|
def diffuse_target(self, x: Tensor, step: int, target: SDXLTarget) -> Tensor:
|
||||||
|
old_solver = self.sd.solver
|
||||||
|
self.sd.solver = target.solver
|
||||||
|
result = self.sd(
|
||||||
x=x,
|
x=x,
|
||||||
step=step,
|
step=step,
|
||||||
clip_text_embedding=target.clip_text_embedding,
|
clip_text_embedding=target.clip_text_embedding,
|
||||||
|
@ -19,3 +29,5 @@ class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]
|
||||||
time_ids=target.time_ids,
|
time_ids=target.time_ids,
|
||||||
condition_scale=target.condition_scale,
|
condition_scale=target.condition_scale,
|
||||||
)
|
)
|
||||||
|
self.sd.solver = old_solver
|
||||||
|
return result
|
||||||
|
|
Loading…
Reference in a new issue