(doc/fluxion/ld) add StableDiffusion_1 docstrings

Laurent 2024-02-02 12:30:41 +00:00 committed by Laureηt
parent af3a29f916
commit 78e9f7728e
3 changed files with 89 additions and 2 deletions

src/refiners/foundationals/latent_diffusion/stable_diffusion_1/__init__.py

@@ -1,6 +1,7 @@
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import SD1ControlnetAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.image_prompt import SD1IPAdapter
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import (
SD1Autoencoder,
StableDiffusion_1,
StableDiffusion_1_Inpainting,
)
@@ -10,6 +11,7 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
__all__ = [
"StableDiffusion_1",
"StableDiffusion_1_Inpainting",
"SD1Autoencoder",
"SD1UNet",
"SD1ControlnetAdapter",
"SD1IPAdapter",

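# Usage sketch (illustrative): the re-exports above make the SD 1.5 classes
# importable directly from the package.
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
    SD1UNet,
    StableDiffusion_1,
    StableDiffusion_1_Inpainting,
)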
src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py

@@ -13,10 +13,24 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
class SD1Autoencoder(LatentDiffusionAutoencoder):
"""Stable Diffusion 1.5 autoencoder model.
Attributes:
encoder_scale: The scale factor applied to the encoded latents.
"""
encoder_scale: float = 0.18215
class StableDiffusion_1(LatentDiffusionModel):
"""Stable Diffusion 1.5 model.
Attributes:
unet: The U-Net model.
clip_text_encoder: The text encoder.
lda: The image autoencoder.
"""
unet: SD1UNet
clip_text_encoder: CLIPTextEncoderL
lda: SD1Autoencoder
@@ -30,6 +44,16 @@ class StableDiffusion_1(LatentDiffusionModel):
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
"""Initializes the model.
Args:
unet: The SD1UNet U-Net model to use.
lda: The SD1Autoencoder image autoencoder to use.
clip_text_encoder: The CLIPTextEncoderL text encoder to use.
solver: The solver to use.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
unet = unet or SD1UNet(in_channels=4)
lda = lda or SD1Autoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
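# Usage sketch (illustrative): every submodel is optional and falls back to
# the defaults constructed above (SD1UNet(in_channels=4), SD1Autoencoder(),
# CLIPTextEncoderL()).
import torch

from refiners.foundationals.latent_diffusion.stable_diffusion_1 import StableDiffusion_1

sd15 = StableDiffusion_1(device="cpu", dtype=torch.float32)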
@@ -45,6 +69,13 @@ class StableDiffusion_1(LatentDiffusionModel):
)
def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
Args:
text: The prompt to compute the CLIP text embedding of.
negative_text: The negative prompt to compute the CLIP text embedding of.
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
"""
conditional_embedding = self.clip_text_encoder(text)
if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)
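# Usage sketch, assuming `sd15` is a StableDiffusion_1 instance: the negative
# and conditional embeddings are concatenated along the batch dimension (dim 0),
# ready for classifier-free guidance. The prompts are illustrative.
clip_text_embedding = sd15.compute_clip_text_embedding(
    text="a cute cat",
    negative_text="blurry, low quality",
)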
@@ -53,10 +84,22 @@ class StableDiffusion_1(LatentDiffusionModel):
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
"""Set the various context parameters required by the U-Net model.
Args:
timestep: The timestep tensor to use.
clip_text_embedding: The CLIP text embedding tensor to use.
"""
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
"""Set whether to enable self-attention guidance.
Args:
enable: Whether to enable self-attention guidance.
scale: The guidance scale to use when enabled.
"""
if enable:
if sag := self._find_sag_adapter():
sag.scale = scale
@@ -67,9 +110,11 @@ class StableDiffusion_1(LatentDiffusionModel):
sag.eject()
def has_self_attention_guidance(self) -> bool:
"""Whether the model has self-attention guidance or not."""
return self._find_sag_adapter() is not None
def _find_sag_adapter(self) -> SD1SAGAdapter | None:
"""Finds the self-attention guidance adapter."""
for p in self.unet.get_parents():
if isinstance(p, SD1SAGAdapter):
return p
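# Usage sketch, assuming `sd15` is a StableDiffusion_1 instance: enabling
# installs an SD1SAGAdapter on the U-Net, disabling ejects it. The scale
# value is illustrative.
sd15.set_self_attention_guidance(enable=True, scale=0.75)
assert sd15.has_self_attention_guidance()
sd15.set_self_attention_guidance(enable=False)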
@@ -78,6 +123,17 @@ class StableDiffusion_1(LatentDiffusionModel):
def compute_self_attention_guidance(
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
"""Compute the self-attention guidance.
Args:
x: The input tensor.
noise: The noise tensor.
step: The step to compute the self-attention guidance at.
clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
Returns:
The computed self-attention guidance.
"""
sag = self._find_sag_adapter()
assert sag is not None
@@ -106,6 +162,14 @@ class StableDiffusion_1(LatentDiffusionModel):
class StableDiffusion_1_Inpainting(StableDiffusion_1):
"""Stable Diffusion 1.5 inpainting model.
Attributes:
unet: The U-Net model.
clip_text_encoder: The text encoder.
lda: The image autoencoder.
"""
def __init__(
self,
unet: SD1UNet | None = None,
@@ -140,6 +204,16 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
mask: Image.Image,
latents_size: tuple[int, int] = (64, 64),
) -> tuple[Tensor, Tensor]:
"""Set the inpainting conditions.
Args:
target_image: The target image to inpaint.
mask: The mask to use for inpainting.
latents_size: The size of the latents to use.
Returns:
The mask latents and the target image latents.
"""
target_image = target_image.convert(mode="RGB")
mask = mask.convert(mode="L")
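# Usage sketch, assuming `sd15_inpainting` is a StableDiffusion_1_Inpainting
# instance; the file names are illustrative. The target image is converted to
# RGB and the mask to single-channel ("L"), as above.
from PIL import Image

mask_latents, target_latents = sd15_inpainting.set_inpainting_conditions(
    target_image=Image.open("target.png"),
    mask=Image.open("mask.png"),
)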
@@ -156,6 +230,17 @@ class StableDiffusion_1_Inpainting(StableDiffusion_1):
def compute_self_attention_guidance(
self, x: Tensor, noise: Tensor, step: int, *, clip_text_embedding: Tensor, **kwargs: Tensor
) -> Tensor:
"""Compute the self-attention guidance.
Args:
x: The input tensor.
noise: The noise tensor.
step: The step to compute the self-attention guidance at.
clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
Returns:
The computed self-attention guidance.
"""
sag = self._find_sag_adapter()
assert sag is not None
assert self.mask_latents is not None

src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py

@@ -25,7 +25,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
Attributes:
unet: The U-Net model.
clip_text_encoder: The text encoder.
-lda (SDXLAutoencoder): The image autoencoder.
+lda: The image autoencoder.
"""
unet: SDXLUNet
@@ -103,7 +103,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
time_ids: Tensor,
**_: Tensor,
) -> None:
"""Sets the various context parameters required by the U-Net model.
"""Set the various context parameters required by the U-Net model.
Args:
timestep: The timestep to set.