implement multi_diffusion for SD1 and SDXL

This commit is contained in:
Benjamin Trom 2023-09-18 10:47:44 +02:00
parent b86521da2f
commit 85095418aa
2 changed files with 61 additions and 0 deletions

View file

@ -0,0 +1,40 @@
from dataclasses import field, dataclass
from torch import Tensor
from PIL import Image
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
return self.ldm(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
scale=target.condition_scale,
)
@dataclass
class InpaintingDiffusionTarget(DiffusionTarget):
target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))
class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
self.ldm.set_inpainting_conditions(
target_image=target.target_image,
mask=target.mask,
)
return self.ldm(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
scale=target.condition_scale,
)

View file

@ -0,0 +1,21 @@
from torch import Tensor
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
class SDXLDiffusionTarget(DiffusionTarget):
pooled_text_embedding: Tensor
time_ids: Tensor
class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
return self.ldm(
x=x,
step=step,
clip_text_embedding=target.clip_text_embedding,
pooled_text_embedding=target.pooled_text_embedding,
time_ids=target.time_ids,
condition_scale=target.condition_scale,
)