mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
(doc/fluxion/ld) add SDXLAutoencoder
docstrings
This commit is contained in:
parent
0c5a7a8269
commit
effd95a1bd
|
@ -10,10 +10,24 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX
|
|||
|
||||
|
||||
class SDXLAutoencoder(LatentDiffusionAutoencoder):
|
||||
"""Stable Diffusion XL autoencoder model.
|
||||
|
||||
Attributes:
|
||||
encoder_scale: The encoder scale to use.
|
||||
"""
|
||||
|
||||
encoder_scale: float = 0.13025
|
||||
|
||||
|
||||
class StableDiffusion_XL(LatentDiffusionModel):
|
||||
"""Stable Diffusion XL model.
|
||||
|
||||
Attributes:
|
||||
unet: The U-Net model.
|
||||
clip_text_encoder: The text encoder.
|
||||
lda (SDXLAutoencoder): The image autoencoder.
|
||||
"""
|
||||
|
||||
unet: SDXLUNet
|
||||
clip_text_encoder: DoubleTextEncoder
|
||||
lda: SDXLAutoencoder
|
||||
|
@ -27,6 +41,16 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
device: Device | str = "cpu",
|
||||
dtype: DType = torch.float32,
|
||||
) -> None:
|
||||
"""Initializes the model.
|
||||
|
||||
Args:
|
||||
unet: The SDXLUNet U-Net model to use.
|
||||
lda: The SDXLAutoencoder image autoencoder to use.
|
||||
clip_text_encoder: The DoubleTextEncoder text encoder to use.
|
||||
solver: The solver to use.
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
unet = unet or SDXLUNet(in_channels=4)
|
||||
lda = lda or SDXLAutoencoder()
|
||||
clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
|
||||
|
@ -42,6 +66,13 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
)
|
||||
|
||||
def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
|
||||
"""Compute the CLIP text embedding associated with the given prompt and negative prompt.
|
||||
|
||||
Args:
|
||||
text: The prompt to compute the CLIP text embedding of.
|
||||
negative_text: The negative prompt to compute the CLIP text embedding of.
|
||||
If not provided, the negative prompt is assumed to be empty (i.e., `""`).
|
||||
"""
|
||||
conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
|
||||
if text == negative_text:
|
||||
return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat(
|
||||
|
@ -57,6 +88,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
|
||||
@property
|
||||
def default_time_ids(self) -> Tensor:
|
||||
"""The default time IDs to use."""
|
||||
# [original_height, original_width, crop_top, crop_left, target_height, target_width]
|
||||
# See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning
|
||||
time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device)
|
||||
|
@ -71,6 +103,14 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
time_ids: Tensor,
|
||||
**_: Tensor,
|
||||
) -> None:
|
||||
"""Sets the various context parameters required by the U-Net model.
|
||||
|
||||
Args:
|
||||
timestep: The timestep to set.
|
||||
clip_text_embedding: The CLIP text embedding to set.
|
||||
pooled_text_embedding: The pooled CLIP text embedding to set.
|
||||
time_ids: The time IDs to set.
|
||||
"""
|
||||
self.unet.set_timestep(timestep=timestep)
|
||||
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)
|
||||
self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding)
|
||||
|
@ -98,6 +138,12 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
)
|
||||
|
||||
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
|
||||
"""Sets the self-attention guidance.
|
||||
|
||||
Args:
|
||||
enable: Whether to enable self-attention guidance or not.
|
||||
scale: The scale to use.
|
||||
"""
|
||||
if enable:
|
||||
if sag := self._find_sag_adapter():
|
||||
sag.scale = scale
|
||||
|
@ -108,9 +154,11 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
sag.eject()
|
||||
|
||||
def has_self_attention_guidance(self) -> bool:
|
||||
"""Whether the model has self-attention guidance or not."""
|
||||
return self._find_sag_adapter() is not None
|
||||
|
||||
def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
|
||||
"""Finds the self-attention guidance adapter."""
|
||||
for p in self.unet.get_parents():
|
||||
if isinstance(p, SDXLSAGAdapter):
|
||||
return p
|
||||
|
@ -127,6 +175,19 @@ class StableDiffusion_XL(LatentDiffusionModel):
|
|||
time_ids: Tensor,
|
||||
**kwargs: Tensor,
|
||||
) -> Tensor:
|
||||
"""Compute the self-attention guidance.
|
||||
|
||||
Args:
|
||||
x: The input tensor.
|
||||
noise: The noise tensor.
|
||||
step: The step to compute the self-attention guidance at.
|
||||
clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with.
|
||||
pooled_text_embedding: The pooled CLIP text embedding to compute the self-attention guidance with.
|
||||
time_ids: The time IDs to compute the self-attention guidance with.
|
||||
|
||||
Returns:
|
||||
The computed self-attention guidance.
|
||||
"""
|
||||
sag = self._find_sag_adapter()
|
||||
assert sag is not None
|
||||
|
||||
|
|
Loading…
Reference in a new issue