implement abstract MultiDiffusion class

Benjamin Trom 2023-09-18 10:47:01 +02:00
parent e319f13d05
commit b86521da2f


@@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType

from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel

MAX_STEPS = 1000

@dataclass
class DiffusionTarget:
    # A rectangular region of the latent canvas, denoised with its own prompt
    # embedding, condition scale and step range, then blended back into the
    # full latent.
    size: tuple[int, int]
    offset: tuple[int, int]
    clip_text_embedding: Tensor
    init_latents: Tensor | None = None
    mask_latent: Tensor | None = None
    weight: int = 1
    condition_scale: float = 7.5
    start_step: int = 0
    end_step: int = MAX_STEPS

    def crop(self, tensor: Tensor, /) -> Tensor:
        # Extract this target's window from a full-size latent tensor.
        height, width = self.size
        top_offset, left_offset = self.offset
        return tensor[:, :, top_offset : top_offset + height, left_offset : left_offset + width]

    def paste(self, tensor: Tensor, /, crop: Tensor) -> Tensor:
        # Write `crop` back into this target's window of the full-size tensor.
        height, width = self.size
        top_offset, left_offset = self.offset
        tensor[:, :, top_offset : top_offset + height, left_offset : left_offset + width] = crop
        return tensor
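
    # Worked example (illustrative values, not from this commit): with
    # size=(64, 64) and offset=(16, 32), crop() reads tensor[:, :, 16:80, 32:96]
    # and paste() writes a crop of that same shape back into the window.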


T = TypeVar("T", bound=LatentDiffusionModel)
D = TypeVar("D", bound=DiffusionTarget)


@dataclass
class MultiDiffusion(Generic[T, D], ABC):
    # Runs one latent diffusion model over many overlapping targets and fuses
    # the per-target results by weighted averaging (the MultiDiffusion scheme).
    ldm: T

    def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) -> Tensor:
        # One denoising step: diffuse each active target's view independently,
        # then reconcile overlaps by weighted averaging.
        num_updates = torch.zeros_like(input=x)
        cumulative_values = torch.zeros_like(input=x)

        for target in targets:
            match step:
                case step if step == target.start_step and target.init_latents is not None:
                    # First active step with init latents: noise them to the
                    # current step instead of cropping the running latent.
                    noise_view = target.crop(noise)
                    view = self.ldm.scheduler.add_noise(
                        x=target.init_latents,
                        noise=noise_view,
                        step=step,
                    )
                case step if target.start_step <= step <= target.end_step:
                    view = target.crop(x)
                case _:
                    # Target is inactive at this step.
                    continue
            view = self.diffuse_target(x=view, step=step, target=target)
            weight = target.weight * target.mask_latent if target.mask_latent is not None else target.weight
            num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight)
            cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view)

        # Average overlapping contributions; keep x where no target wrote.
        return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)

    @abstractmethod
    def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ...

    @property
    def steps(self) -> list[int]:
        return self.ldm.steps

    @property
    def device(self) -> Device:
        return self.ldm.device

    @property
    def dtype(self) -> DType:
        return self.ldm.dtype

    def decode_latents(self, x: Tensor) -> Image.Image:
        return self.ldm.lda.decode_latents(x=x)

    @staticmethod
    def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
        # Offsets at which a 64x64 latent window still fits inside the canvas,
        # e.g. generate_offset_grid(size=(64, 96), stride=16) yields
        # [(0, 0), (0, 16), (0, 32)].
        height, width = size
        return [
            (y, x)
            for y in range(0, height, stride)
            for x in range(0, width, stride)
            if y + 64 <= height and x + 64 <= width
        ]
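
For context, a concrete subclass only needs to implement diffuse_target. The sketch below is illustrative and not part of this commit: the StableDiffusion_1 import and its call signature, as well as every name in the commented driving loop, are assumptions based on the surrounding refiners API.

# Hypothetical subclass, not part of this commit. Assumes StableDiffusion_1
# can be called as ldm(x, step=..., clip_text_embedding=..., condition_scale=...).
from refiners.foundationals.latent_diffusion import StableDiffusion_1


class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
    def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
        # Denoise this target's view with its own prompt embedding and scale.
        return self.ldm(
            x,
            step=step,
            clip_text_embedding=target.clip_text_embedding,
            condition_scale=target.condition_scale,
        )


# Illustrative driving loop: `x` and `noise` are full-canvas latents and
# `targets` is a prebuilt list of DiffusionTarget instances tiling the canvas.
#
#     md = SD1MultiDiffusion(ldm=sd)
#     for step in md.steps:
#         x = md(x, noise=noise, step=step, targets=targets)
#     image = md.decode_latents(x)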