mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(doc/fluxion/ld) add StableDiffusion_1
docstrings
This commit is contained in:
parent
af3a29f916
commit
78e9f7728e
|
@ -1,6 +1,7 @@
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
|
||||||
|
SD1Autoencoder,
|
||||||
StableDiffusion_1,
|
StableDiffusion_1,
|
||||||
StableDiffusion_1_Inpainting,
|
StableDiffusion_1_Inpainting,
|
||||||
)
|
)
|
||||||
|
@ -10,6 +11,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1U
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"StableDiffusion_1",
|
"StableDiffusion_1",
|
||||||
"StableDiffusion_1_Inpainting",
|
"StableDiffusion_1_Inpainting",
|
||||||
|
"SD1Autoencoder",
|
||||||
"SD1UNet",
|
"SD1UNet",
|
||||||
"SD1ControlnetAdapter",
|
"SD1ControlnetAdapter",
|
||||||
"SD1IPAdapter",
|
"SD1IPAdapter",
|
||||||
|
|
|
@ -13,10 +13,24 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1U
|
||||||
|
|
||||||
|
|
||||||
class SD1Autoencoder(LatentDiffusionAutoencoder):
|
class SD1Autoencoder(LatentDiffusionAutoencoder):
|
||||||
|
"""Stable Diffusion 1.5 autoencoder model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
encoder_scale: The encoder scale to use.
|
||||||
|
"""
|
||||||
|
|
||||||
encoder_scale: float = 0.18215
|
encoder_scale: float = 0.18215
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion_1(LatentDiffusionModel):
|
class StableDiffusion_1(LatentDiffusionModel):
|
||||||
|
"""Stable Diffusion 1.5 model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
unet: The U-Net model.
|
||||||
|
clip_text_encoder: The text encoder.
|
||||||
|
lda: The image autoencoder.
|
||||||
|
"""
|
||||||
|
|
||||||
unet: SD1UNet
|
unet: SD1UNet
|
||||||
clip_text_encoder: CLIPTextEncoderL
|
clip_text_encoder: CLIPTextEncoderL
|
||||||
lda: SD1Autoencoder
|
lda: SD1Autoencoder
|
||||||
|
@ -30,6 +44,16 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType = torch.float32,
|
dtype: DType = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initializes the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
unet: The SD1UNet U-Net model to use.
|
||||||
|
lda: The SD1Autoencoder image autoencoder to use.
|
||||||
|
clip_text_encoder: The CLIPTextEncoderL text encoder to use.
|
||||||
|
solver: The solver to use.
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
unet = unet or SD1UNet(in_channels=4)
|
unet = unet or SD1UNet(in_channels=4)
|
||||||
lda = lda or SD1Autoencoder()
|
lda = lda or SD1Autoencoder()
|
||||||
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
|
||||||
|
@ -45,6 +69,13 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
|
def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
|
||||||
|
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The prompt to compute the CLIP text embedding of.
|
||||||
|
negative_text: The negative prompt to compute the CLIP text embedding of.
|
||||||
|
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
|
||||||
|
"""
|
||||||
conditional_embedding = self.clip_text_encoder(text)
|
conditional_embedding = self.clip_text_encoder(text)
|
||||||
if text == negative_text:
|
if text == negative_text:
|
||||||
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
|
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
|
||||||
|
@ -53,10 +84,22 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
|
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
|
||||||
|
|
||||||
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
|
||||||
|
"""Set the various context parameters required by the U-Net model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep: The timestep tensor to use.
|
||||||
|
clip_text_embedding: The CLIP text embedding tensor to use.
|
||||||
|
"""
|
||||||
self.unet.set_timestep(timestep=timestep)
|
self.unet.set_timestep(timestep=timestep)
|
||||||
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||||
|
|
||||||
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
|
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
|
||||||
|
"""Set whether to enable self-attention guidance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable: Whether to enable self-attention guidance.
|
||||||
|
scale: The scale to use.
|
||||||
|
"""
|
||||||
if enable:
|
if enable:
|
||||||
if sag := self._find_sag_adapter():
|
if sag := self._find_sag_adapter():
|
||||||
sag.scale = scale
|
sag.scale = scale
|
||||||
|
@ -67,9 +110,11 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
sag.eject()
|
sag.eject()
|
||||||
|
|
||||||
def has_self_attention_guidance(self) -> bool:
|
def has_self_attention_guidance(self) -> bool:
|
||||||
|
"""Whether the model has self-attention guidance or not."""
|
||||||
return self._find_sag_adapter() is not None
|
return self._find_sag_adapter() is not None
|
||||||
|
|
||||||
def _find_sag_adapter(self) -> SD1SAGAdapter | None:
|
def _find_sag_adapter(self) -> SD1SAGAdapter | None:
|
||||||
|
"""Finds the self-attention guidance adapter."""
|
||||||
for p in self.unet.get_parents():
|
for p in self.unet.get_parents():
|
||||||
if isinstance(p, SD1SAGAdapter):
|
if isinstance(p, SD1SAGAdapter):
|
||||||
return p
|
return p
|
||||||
|
@ -78,6 +123,17 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
def compute_self_attention_guidance(
|
def compute_self_attention_guidance(
|
||||||
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
"""Compute the self-attention guidance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor.
|
||||||
|
noise: The noise tensor.
|
||||||
|
step: The step to compute the self-attention guidance at.
|
||||||
|
clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The computed self-attention guidance.
|
||||||
|
"""
|
||||||
sag = self._find_sag_adapter()
|
sag = self._find_sag_adapter()
|
||||||
assert sag is not None
|
assert sag is not None
|
||||||
|
|
||||||
|
@ -106,6 +162,14 @@ class StableDiffusion_1(LatentDiffusionModel):
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
|
"""Stable Diffusion 1.5 inpainting model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
unet: The U-Net model.
|
||||||
|
clip_text_encoder: The text encoder.
|
||||||
|
lda: The image autoencoder.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
unet: SD1UNet | None = None,
|
unet: SD1UNet | None = None,
|
||||||
|
@ -140,6 +204,16 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
mask: Image.Image,
|
mask: Image.Image,
|
||||||
latents_size: tuple[int, int] = (64, 64),
|
latents_size: tuple[int, int] = (64, 64),
|
||||||
) -> tuple[Tensor, Tensor]:
|
) -> tuple[Tensor, Tensor]:
|
||||||
|
"""Set the inpainting conditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_image: The target image to inpaint.
|
||||||
|
mask: The mask to use for inpainting.
|
||||||
|
latents_size: The size of the latents to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The mask latents and the target image latents.
|
||||||
|
"""
|
||||||
target_image = target_image.convert(mode="RGB")
|
target_image = target_image.convert(mode="RGB")
|
||||||
mask = mask.convert(mode="L")
|
mask = mask.convert(mode="L")
|
||||||
|
|
||||||
|
@ -156,6 +230,17 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
|
||||||
def compute_self_attention_guidance(
|
def compute_self_attention_guidance(
|
||||||
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
"""Compute the self-attention guidance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor.
|
||||||
|
noise: The noise tensor.
|
||||||
|
step: The step to compute the self-attention guidance at.
|
||||||
|
clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The computed self-attention guidance.
|
||||||
|
"""
|
||||||
sag = self._find_sag_adapter()
|
sag = self._find_sag_adapter()
|
||||||
assert sag is not None
|
assert sag is not None
|
||||||
assert self.mask_latents is not None
|
assert self.mask_latents is not None
|
||||||
|
|
|
@ -25,7 +25,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
||||||
Attributes:
|
Attributes:
|
||||||
unet: The U-Net model.
|
unet: The U-Net model.
|
||||||
clip_text_encoder: The text encoder.
|
clip_text_encoder: The text encoder.
|
||||||
lda (SDXLAutoencoder): The image autoencoder.
|
lda: The image autoencoder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
unet: SDXLUNet
|
unet: SDXLUNet
|
||||||
|
@ -103,7 +103,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
||||||
time_ids: Tensor,
|
time_ids: Tensor,
|
||||||
**_: Tensor,
|
**_: Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Sets the various context parameters required by the U-Net model.
|
"""Set the various context parameters required by the U-Net model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timestep: The timestep to set.
|
timestep: The timestep to set.
|
||||||
|
|
Loading…
Reference in a new issue