(doc/fluxion/ld) add SDXLAutoencoder docstrings

This commit is contained in:
Laurent 2024-02-02 11:11:41 +00:00 committed by Laureηt
parent 7309a0985e
commit 08c453345a

View file

@ -10,10 +10,24 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX
class SDXLAutoencoder(LatentDiffusionAutoencoder): class SDXLAutoencoder(LatentDiffusionAutoencoder):
"""Stable Diffusion XL autoencoder model.
Attributes:
encoder_scale: The encoder scale to use.
"""
encoder_scale: float = 0.13025 encoder_scale: float = 0.13025
class StableDiffusion_XL(LatentDiffusionModel): class StableDiffusion_XL(LatentDiffusionModel):
"""Stable Diffusion XL model.
Attributes:
unet: The U-Net model.
clip_text_encoder: The text encoder.
lda (SDXLAutoencoder): The image autoencoder.
"""
unet: SDXLUNet unet: SDXLUNet
clip_text_encoder: DoubleTextEncoder clip_text_encoder: DoubleTextEncoder
lda: SDXLAutoencoder lda: SDXLAutoencoder
@ -27,6 +41,16 @@ class StableDiffusion_XL(LatentDiffusionModel):
device: Device | str = "cpu", device: Device | str = "cpu",
dtype: DType = torch.float32, dtype: DType = torch.float32,
) -> None: ) -> None:
"""Initializes the model.
Args:
unet: The SDXLUNet U-Net model to use.
lda: The SDXLAutoencoder image autoencoder to use.
clip_text_encoder: The DoubleTextEncoder text encoder to use.
solver: The solver to use.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
unet = unet or SDXLUNet(in_channels=4) unet = unet or SDXLUNet(in_channels=4)
lda = lda or SDXLAutoencoder() lda = lda or SDXLAutoencoder()
clip_text_encoder = clip_text_encoder or DoubleTextEncoder() clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
@ -42,6 +66,13 @@ class StableDiffusion_XL(LatentDiffusionModel):
) )
def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]: def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
Args:
text: The prompt to compute the CLIP text embedding of.
negative_text: The negative prompt to compute the CLIP text embedding of.
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
"""
conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text) conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
if text == negative_text: if text == negative_text:
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat( return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
@ -57,6 +88,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
@property @property
def default_time_ids(self) -> Tensor: def default_time_ids(self) -> Tensor:
"""The default time IDs to use."""
# [original_height, original_width, crop_top, crop_left, target_height, target_width] # [original_height, original_width, crop_top, crop_left, target_height, target_width]
# See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning # See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device) time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
@ -71,6 +103,14 @@ class StableDiffusion_XL(LatentDiffusionModel):
time_ids: Tensor, time_ids: Tensor,
**_: Tensor, **_: Tensor,
) -> None: ) -> None:
"""Sets the various context parameters required by the U-Net model.
Args:
timestep: The timestep to set.
clip_text_embedding: The CLIP text embedding to set.
pooled_text_embedding: The pooled CLIP text embedding to set.
time_ids: The time IDs to set.
"""
self.unet.set_timestep(timestep=timestep) self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding) self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
@ -98,6 +138,12 @@ class StableDiffusion_XL(LatentDiffusionModel):
) )
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
"""Sets the self-attention guidance.
Args:
enable: Whether to enable self-attention guidance or not.
scale: The scale to use.
"""
if enable: if enable:
if sag := self._find_sag_adapter(): if sag := self._find_sag_adapter():
sag.scale = scale sag.scale = scale
@ -108,9 +154,11 @@ class StableDiffusion_XL(LatentDiffusionModel):
sag.eject() sag.eject()
def has_self_attention_guidance(self) -> bool: def has_self_attention_guidance(self) -> bool:
"""Whether the model has self-attention guidance or not."""
return self._find_sag_adapter() is not None return self._find_sag_adapter() is not None
def _find_sag_adapter(self) -> SDXLSAGAdapter | None: def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
"""Finds the self-attention guidance adapter."""
for p in self.unet.get_parents(): for p in self.unet.get_parents():
if isinstance(p, SDXLSAGAdapter): if isinstance(p, SDXLSAGAdapter):
return p return p
@ -127,6 +175,19 @@ class StableDiffusion_XL(LatentDiffusionModel):
time_ids: Tensor, time_ids: Tensor,
**kwargs: Tensor, **kwargs: Tensor,
) -> Tensor: ) -> Tensor:
"""Compute the self-attention guidance.
Args:
x: The input tensor.
noise: The noise tensor.
step: The step to compute the self-attention guidance at.
clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
pooled_text_embedding: The pooled CLIP text embedding to compute the self-attention guidance with.
time_ids: The time IDs to compute the self-attention guidance with.
Returns:
The computed self-attention guidance.
"""
sag = self._find_sag_adapter() sag = self._find_sag_adapter()
assert sag is not None assert sag is not None