mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
implement multi_diffusion for SD1 and SDXL
This commit is contained in:
parent
b86521da2f
commit
85095418aa
|
@ -0,0 +1,40 @@
|
|||
from dataclasses import field, dataclass
|
||||
from torch import Tensor
|
||||
from PIL import Image
|
||||
|
||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||
StableDiffusion_1,
|
||||
StableDiffusion_1_Inpainting,
|
||||
)
|
||||
|
||||
|
||||
class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
|
||||
def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
|
||||
return self.ldm(
|
||||
x=x,
|
||||
step=step,
|
||||
clip_text_embedding=target.clip_text_embedding,
|
||||
scale=target.condition_scale,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InpaintingDiffusionTarget(DiffusionTarget):
|
||||
target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
|
||||
mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))
|
||||
|
||||
|
||||
class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
|
||||
def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
|
||||
self.ldm.set_inpainting_conditions(
|
||||
target_image=target.target_image,
|
||||
mask=target.mask,
|
||||
)
|
||||
|
||||
return self.ldm(
|
||||
x=x,
|
||||
step=step,
|
||||
clip_text_embedding=target.clip_text_embedding,
|
||||
scale=target.condition_scale,
|
||||
)
|
|
@ -0,0 +1,21 @@
|
|||
from torch import Tensor
|
||||
|
||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||
|
||||
|
||||
class SDXLDiffusionTarget(DiffusionTarget):
|
||||
pooled_text_embedding: Tensor
|
||||
time_ids: Tensor
|
||||
|
||||
|
||||
class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
|
||||
def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
|
||||
return self.ldm(
|
||||
x=x,
|
||||
step=step,
|
||||
clip_text_embedding=target.clip_text_embedding,
|
||||
pooled_text_embedding=target.pooled_text_embedding,
|
||||
time_ids=target.time_ids,
|
||||
condition_scale=target.condition_scale,
|
||||
)
|
Loading…
Reference in a new issue