implement abstract MultiDiffusion class
parent e319f13d05
commit b86521da2f
@@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic, TypeVar

import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType

from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel

MAX_STEPS = 1000


@dataclass
class DiffusionTarget:
    # A rectangular view into the full latent canvas, with its own prompt
    # embedding, guidance scale, optional initial latents and soft mask.
    size: tuple[int, int]
    offset: tuple[int, int]
    clip_text_embedding: Tensor
    init_latents: Tensor | None = None
    mask_latent: Tensor | None = None
    weight: int = 1
    condition_scale: float = 7.5
    start_step: int = 0
    end_step: int = MAX_STEPS

    def crop(self, tensor: Tensor, /) -> Tensor:
        # Extract this target's window from a (N, C, H, W) latent tensor.
        height, width = self.size
        top_offset, left_offset = self.offset
        return tensor[:, :, top_offset : top_offset + height, left_offset : left_offset + width]

    def paste(self, tensor: Tensor, /, crop: Tensor) -> Tensor:
        # Write `crop` back into this target's window, in place.
        height, width = self.size
        top_offset, left_offset = self.offset
        tensor[:, :, top_offset : top_offset + height, left_offset : left_offset + width] = crop
        return tensor


T = TypeVar("T", bound=LatentDiffusionModel)
D = TypeVar("D", bound=DiffusionTarget)


@dataclass
class MultiDiffusion(Generic[T, D], ABC):
    ldm: T

    def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) -> Tensor:
        # Run one denoising step on each target's view, then fuse the views
        # back into the canvas as a weighted average of overlapping updates.
        num_updates = torch.zeros_like(input=x)
        cumulative_values = torch.zeros_like(input=x)

        for target in targets:
            match step:
                case step if step == target.start_step and target.init_latents is not None:
                    # First step for this target: noise its initial latents.
                    noise_view = target.crop(noise)
                    view = self.ldm.scheduler.add_noise(
                        x=target.init_latents,
                        noise=noise_view,
                        step=step,
                    )
                case step if target.start_step <= step <= target.end_step:
                    view = target.crop(x)
                case _:
                    # Target is inactive at this step.
                    continue
            view = self.diffuse_target(x=view, step=step, target=target)
            weight = target.weight * target.mask_latent if target.mask_latent is not None else target.weight
            num_updates = target.paste(num_updates, crop=target.crop(num_updates) + weight)
            cumulative_values = target.paste(cumulative_values, crop=target.crop(cumulative_values) + weight * view)

        # Regions covered by no target keep their previous value.
        return torch.where(condition=num_updates > 0, input=cumulative_values / num_updates, other=x)

    @abstractmethod
    def diffuse_target(self, x: Tensor, step: int, target: D) -> Tensor: ...

    @property
    def steps(self) -> list[int]:
        return self.ldm.steps

    @property
    def device(self) -> Device:
        return self.ldm.device

    @property
    def dtype(self) -> DType:
        return self.ldm.dtype

    def decode_latents(self, x: Tensor) -> Image.Image:
        return self.ldm.lda.decode_latents(x=x)

    @staticmethod
    def generate_offset_grid(size: tuple[int, int], stride: int = 8) -> list[tuple[int, int]]:
        # Enumerate top-left offsets, `stride` latent pixels apart, at which
        # a 64x64 latent tile fits entirely inside the canvas.
        height, width = size

        return [
            (y, x)
            for y in range(0, height, stride)
            for x in range(0, width, stride)
            if y + 64 <= height and x + 64 <= width
        ]
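
For context, a minimal sketch of what a concrete subclass can look like. It assumes the wrapped LatentDiffusionModel is callable with a latent view, the current step, and the target's clip_text_embedding and condition_scale; the subclass name is illustrative, not part of this commit:

# Hypothetical concrete subclass: a single denoising step on a view is
# delegated to the wrapped latent diffusion model.
class TiledMultiDiffusion(MultiDiffusion[LatentDiffusionModel, DiffusionTarget]):
    def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
        return self.ldm(
            x,
            step=step,
            clip_text_embedding=target.clip_text_embedding,
            condition_scale=target.condition_scale,
        )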
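Likewise, a rough driving loop built on generate_offset_grid for panorama-style tiling. Everything below is illustrative rather than part of the commit: `ldm` and `text_embedding` are assumed to come from an already-configured refiners pipeline, the 4-channel latent shape matches Stable Diffusion's latent space, and the 64x64 tile size mirrors the bound hard-coded in generate_offset_grid:

# Illustrative usage (assumes `ldm` and `text_embedding` already exist).
md = TiledMultiDiffusion(ldm=ldm)

canvas = (64, 192)  # latent height x width, i.e. a 512x1536 pixel panorama
targets = [
    DiffusionTarget(
        size=(64, 64),  # matches the 64x64 tile bound in generate_offset_grid
        offset=offset,
        clip_text_embedding=text_embedding,
    )
    for offset in MultiDiffusion.generate_offset_grid(size=canvas, stride=16)
]

x = torch.randn(1, 4, *canvas, device=md.device, dtype=md.dtype)
noise = x.clone()  # only consumed by targets that carry init_latents
for step in md.steps:
    x = md(x, noise=noise, step=step, targets=targets)
image = md.decode_latents(x)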