From 289261f2fb8d9265e782ff894f18c7e7171f5cd1 Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 2 Feb 2024 10:15:06 +0000 Subject: [PATCH] (doc/fluxion/ld) add `SD1UNet` docstrings --- .../stable_diffusion_1/unet.py | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py index 3d8c967..8f102a4 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/unet.py @@ -234,7 +234,23 @@ class ResidualConcatenator(fl.Chain): class SD1UNet(fl.Chain): - def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None: + """Stable Diffusion 1.5 U-Net. + + See [[arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) for more details.""" + + def __init__( + self, + in_channels: int, + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: + """Initialize the U-Net. + + Args: + in_channels: The number of input channels. + device: The PyTorch device to use for computation. + dtype: The PyTorch dtype to use for computation. + """ self.in_channels = in_channels super().__init__( TimestepEncoder(device=device, dtype=dtype), @@ -282,7 +298,23 @@ class SD1UNet(fl.Chain): } def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None: + """Set the CLIP text embedding. + + Note: + This context is required by the `CLIPLCrossAttention` blocks. + + Args: + clip_text_embedding: The CLIP text embedding. + """ self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding}) def set_timestep(self, timestep: Tensor) -> None: + """Set the timestep. + + Note: + This context is required by `TimestepEncoder`. + + Args: + timestep: The timestep. + """ self.set_context("diffusion", {"timestep": timestep})