mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
implement multi_diffusion for SD1 and SDXL
This commit is contained in:
parent
b86521da2f
commit
85095418aa
|
@ -0,0 +1,40 @@
|
||||||
|
from dataclasses import field, dataclass
|
||||||
|
from torch import Tensor
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||||
|
StableDiffusion_1,
|
||||||
|
StableDiffusion_1_Inpainting,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SD1MultiDiffusion(MultiDiffusion[StableDiffusion_1, DiffusionTarget]):
|
||||||
|
def diffuse_target(self, x: Tensor, step: int, target: DiffusionTarget) -> Tensor:
|
||||||
|
return self.ldm(
|
||||||
|
x=x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=target.clip_text_embedding,
|
||||||
|
scale=target.condition_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InpaintingDiffusionTarget(DiffusionTarget):
|
||||||
|
target_image: Image.Image = field(default_factory=lambda: Image.new(mode="RGB", size=(512, 512), color=255))
|
||||||
|
mask: Image.Image = field(default_factory=lambda: Image.new(mode="L", size=(512, 512), color=255))
|
||||||
|
|
||||||
|
|
||||||
|
class SD1InpaintingMultiDiffusion(MultiDiffusion[StableDiffusion_1_Inpainting, InpaintingDiffusionTarget]):
|
||||||
|
def diffuse_target(self, x: Tensor, step: int, target: InpaintingDiffusionTarget) -> Tensor:
|
||||||
|
self.ldm.set_inpainting_conditions(
|
||||||
|
target_image=target.target_image,
|
||||||
|
mask=target.mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.ldm(
|
||||||
|
x=x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=target.clip_text_embedding,
|
||||||
|
scale=target.condition_scale,
|
||||||
|
)
|
|
@ -0,0 +1,21 @@
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLDiffusionTarget(DiffusionTarget):
|
||||||
|
pooled_text_embedding: Tensor
|
||||||
|
time_ids: Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLMultiDiffusion(MultiDiffusion[StableDiffusion_XL, SDXLDiffusionTarget]):
|
||||||
|
def diffuse_target(self, x: Tensor, step: int, target: SDXLDiffusionTarget) -> Tensor:
|
||||||
|
return self.ldm(
|
||||||
|
x=x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=target.clip_text_embedding,
|
||||||
|
pooled_text_embedding=target.pooled_text_embedding,
|
||||||
|
time_ids=target.time_ids,
|
||||||
|
condition_scale=target.condition_scale,
|
||||||
|
)
|
Loading…
Reference in a new issue