diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py index b272f0a..b3891ea 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py @@ -10,10 +10,24 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX class SDXLAutoencoder(LatentDiffusionAutoencoder): + """Stable Diffusion XL autoencoder model. + + Attributes: + encoder_scale: The encoder scale to use. + """ + encoder_scale: float = 0.13025 class StableDiffusion_XL(LatentDiffusionModel): + """Stable Diffusion XL model. + + Attributes: + unet: The U-Net model. + clip_text_encoder: The text encoder. + lda (SDXLAutoencoder): The image autoencoder. + """ + unet: SDXLUNet clip_text_encoder: DoubleTextEncoder lda: SDXLAutoencoder @@ -27,6 +41,16 @@ class StableDiffusion_XL(LatentDiffusionModel): device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: + """Initializes the model. + + Args: + unet: The SDXLUNet U-Net model to use. + lda: The SDXLAutoencoder image autoencoder to use. + clip_text_encoder: The DoubleTextEncoder text encoder to use. + solver: The solver to use. + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ unet = unet or SDXLUNet(in_channels=4) lda = lda or SDXLAutoencoder() clip_text_encoder = clip_text_encoder or DoubleTextEncoder() @@ -42,6 +66,13 @@ class StableDiffusion_XL(LatentDiffusionModel): ) def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]: + """Compute the CLIP text embedding associated with the given prompt and negative prompt. + + Args: + text: The prompt to compute the CLIP text embedding of. + negative_text: The negative prompt to compute the CLIP text embedding of. + If not provided, the negative prompt is assumed to be empty (i.e., `""`). + """ conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text) if text == negative_text: return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0), torch.cat( @@ -57,6 +88,7 @@ class StableDiffusion_XL(LatentDiffusionModel): @property def default_time_ids(self) -> Tensor: + """The default time IDs to use.""" # [original_height, original_width, crop_top, crop_left, target_height, target_width] # See https://arxiv.org/abs/2307.01952 > 2.2 Micro-Conditioning time_ids = torch.tensor(data=[1024, 1024, 0, 0, 1024, 1024], device=self.device) @@ -71,6 +103,14 @@ class StableDiffusion_XL(LatentDiffusionModel): time_ids: Tensor, **_: Tensor, ) -> None: + """Sets the various context parameters required by the U-Net model. + + Args: + timestep: The timestep to set. + clip_text_embedding: The CLIP text embedding to set. + pooled_text_embedding: The pooled CLIP text embedding to set. + time_ids: The time IDs to set. + """ self.unet.set_timestep(timestep=timestep) self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) self.unet.set_pooled_text_embedding(pooled_text_embedding=pooled_text_embedding) @@ -98,6 +138,12 @@ class StableDiffusion_XL(LatentDiffusionModel): ) def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: + """Sets the self-attention guidance. + + Args: + enable: Whether to enable self-attention guidance or not. + scale: The scale to use. + """ if enable: if sag := self._find_sag_adapter(): sag.scale = scale @@ -108,9 +154,11 @@ class StableDiffusion_XL(LatentDiffusionModel): sag.eject() def has_self_attention_guidance(self) -> bool: + """Whether the model has self-attention guidance or not.""" return self._find_sag_adapter() is not None def _find_sag_adapter(self) -> SDXLSAGAdapter | None: + """Finds the self-attention guidance adapter.""" for p in self.unet.get_parents(): if isinstance(p, SDXLSAGAdapter): return p @@ -127,6 +175,19 @@ class StableDiffusion_XL(LatentDiffusionModel): time_ids: Tensor, **kwargs: Tensor, ) -> Tensor: + """Compute the self-attention guidance. + + Args: + x: The input tensor. + noise: The noise tensor. + step: The step to compute the self-attention guidance at. + clip_text_embedding: The CLIP text embedding to compute the self-attention guidance with. + pooled_text_embedding: The pooled CLIP text embedding to compute the self-attention guidance with. + time_ids: The time IDs to compute the self-attention guidance with. + + Returns: + The computed self-attention guidance. + """ sag = self._find_sag_adapter() assert sag is not None