(doc/fluxion/ld) add SD1UNet docstrings

Laurent 2024-02-02 10:15:06 +00:00 committed by Laureηt
parent fae08c058e
commit 289261f2fb


@@ -234,7 +234,23 @@ class ResidualConcatenator(fl.Chain):
 class SD1UNet(fl.Chain):
-    def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
+    """Stable Diffusion 1.5 U-Net.
+
+    See [[arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) for more details.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        device: Device | str | None = None,
+        dtype: DType | None = None,
+    ) -> None:
+        """Initialize the U-Net.
+
+        Args:
+            in_channels: The number of input channels.
+            device: The PyTorch device to use for computation.
+            dtype: The PyTorch dtype to use for computation.
+        """
         self.in_channels = in_channels
         super().__init__(
             TimestepEncoder(device=device, dtype=dtype),
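
As a minimal construction sketch (not part of this commit): the import path and the in_channels value below are assumptions, based on Stable Diffusion 1.5 operating on 4-channel latents.

    import torch
    from refiners.foundationals.latent_diffusion import SD1UNet  # assumed import path

    # SD 1.5 latents have 4 channels, so in_channels=4 is the usual value.
    unet = SD1UNet(in_channels=4, device="cpu", dtype=torch.float32)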
@@ -282,7 +298,23 @@ class SD1UNet(fl.Chain):
         }

     def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
+        """Set the CLIP text embedding.
+
+        Note:
+            This context is required by the `CLIPLCrossAttention` blocks.
+
+        Args:
+            clip_text_embedding: The CLIP text embedding.
+        """
         self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})

     def set_timestep(self, timestep: Tensor) -> None:
+        """Set the timestep.
+
+        Note:
+            This context is required by `TimestepEncoder`.
+
+        Args:
+            timestep: The timestep.
+        """
         self.set_context("diffusion", {"timestep": timestep})