From fae08c058e40c0022d9d20cc0db7c5d9160942d0 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 2 Feb 2024 10:09:11 +0000 Subject: [PATCH] (doc/fluxion/ld) add SDXLUNet docstrings --- .../stable_diffusion_xl/unet.py | 51 ++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index afab21a..6514442 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -239,7 +239,24 @@ class OutputBlock(fl.Chain): class SDXLUNet(fl.Chain): - def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None: + """Stable Diffusion XL U-Net. + + See [[arXiv:2307.01952] SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/abs/2307.01952) for more details. + """ + + def __init__( + self, + in_channels: int, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize the U-Net. + + Args: + in_channels: Number of input channels. + device: Device to use for computation. + dtype: Data type to use for computation. + """ self.in_channels = in_channels super().__init__( TimestepEncoder(device=device, dtype=dtype), @@ -273,13 +290,45 @@ class SDXLUNet(fl.Chain): } def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None: + """Set the clip text embedding context. + + Note: + This context is required by the `SDXLCrossAttention` blocks. + + Args: + clip_text_embedding: The CLIP text embedding tensor. + """ self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding}) def set_timestep(self, timestep: Tensor) -> None: + """Set the timestep context. + + Note: + This is required by `TimestepEncoder`. + + Args: + timestep: The timestep tensor. + """ self.set_context(context="diffusion", value={"timestep": timestep}) def set_time_ids(self, time_ids: Tensor) -> None: + """Set the time IDs context. + + Note: + This is required by `TextTimeEmbedding`. + + Args: + time_ids: The time IDs tensor. + """ self.set_context(context="diffusion", value={"time_ids": time_ids}) def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None: + """Set the pooled text embedding context. + + Note: + This is required by `TextTimeEmbedding`. + + Args: + pooled_text_embedding: The pooled text embedding tensor. + """ self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})