(doc/fluxion/ld) add SDXLUNet docstrings

This commit is contained in:
Laurent 2024-02-02 10:09:11 +00:00 committed by Laureηt
parent f9d4ec18e1
commit f93a145e15

View file

@ -239,7 +239,24 @@ class OutputBlock(fl.Chain):
class SDXLUNet(fl.Chain): class SDXLUNet(fl.Chain):
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None: """Stable Diffusion XL U-Net.
See [[arXiv:2307.01952] SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/abs/2307.01952) for more details.
"""
def __init__(
self,
in_channels: int,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
"""Initialize the U-Net.
Args:
in_channels: Number of input channels.
device: Device to use for computation.
dtype: Data type to use for computation.
"""
self.in_channels = in_channels self.in_channels = in_channels
super().__init__( super().__init__(
TimestepEncoder(device=device, dtype=dtype), TimestepEncoder(device=device, dtype=dtype),
@ -273,13 +290,45 @@ class SDXLUNet(fl.Chain):
} }
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None: def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
"""Set the clip text embedding context.
Note:
This context is required by the `SDXLCrossAttention` blocks.
Args:
clip_text_embedding: The CLIP text embedding tensor.
"""
self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding}) self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})
def set_timestep(self, timestep: Tensor) -> None: def set_timestep(self, timestep: Tensor) -> None:
"""Set the timestep context.
Note:
This is required by `TimestepEncoder`.
Args:
timestep: The timestep tensor.
"""
self.set_context(context="diffusion", value={"timestep": timestep}) self.set_context(context="diffusion", value={"timestep": timestep})
def set_time_ids(self, time_ids: Tensor) -> None: def set_time_ids(self, time_ids: Tensor) -> None:
"""Set the time IDs context.
Note:
This is required by `TextTimeEmbedding`.
Args:
time_ids: The time IDs tensor.
"""
self.set_context(context="diffusion", value={"time_ids": time_ids}) self.set_context(context="diffusion", value={"time_ids": time_ids})
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None: def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
"""Set the pooled text embedding context.
Note:
This is required by `TextTimeEmbedding`.
Args:
pooled_text_embedding: The pooled text embedding tensor.
"""
self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding}) self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})