mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
(doc/fluxion/ld) add SDXLUNet docstrings
This commit is contained in:
parent
270357ed29
commit
fae08c058e
|
@ -239,7 +239,24 @@ class OutputBlock(fl.Chain):
|
|||
|
||||
|
||||
class SDXLUNet(fl.Chain):
|
||||
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
"""Stable Diffusion XL U-Net.
|
||||
|
||||
See [[arXiv:2307.01952] SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/abs/2307.01952) for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Initialize the U-Net.
|
||||
|
||||
Args:
|
||||
in_channels: Number of input channels.
|
||||
device: Device to use for computation.
|
||||
dtype: Data type to use for computation.
|
||||
"""
|
||||
self.in_channels = in_channels
|
||||
super().__init__(
|
||||
TimestepEncoder(device=device, dtype=dtype),
|
||||
|
@ -273,13 +290,45 @@ class SDXLUNet(fl.Chain):
|
|||
}
|
||||
|
||||
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
||||
"""Set the clip text embedding context.
|
||||
|
||||
Note:
|
||||
This context is required by the `SDXLCrossAttention` blocks.
|
||||
|
||||
Args:
|
||||
clip_text_embedding: The CLIP text embedding tensor.
|
||||
"""
|
||||
self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})
|
||||
|
||||
def set_timestep(self, timestep: Tensor) -> None:
|
||||
"""Set the timestep context.
|
||||
|
||||
Note:
|
||||
This is required by `TimestepEncoder`.
|
||||
|
||||
Args:
|
||||
timestep: The timestep tensor.
|
||||
"""
|
||||
self.set_context(context="diffusion", value={"timestep": timestep})
|
||||
|
||||
def set_time_ids(self, time_ids: Tensor) -> None:
|
||||
"""Set the time IDs context.
|
||||
|
||||
Note:
|
||||
This is required by `TextTimeEmbedding`.
|
||||
|
||||
Args:
|
||||
time_ids: The time IDs tensor.
|
||||
"""
|
||||
self.set_context(context="diffusion", value={"time_ids": time_ids})
|
||||
|
||||
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
|
||||
"""Set the pooled text embedding context.
|
||||
|
||||
Note:
|
||||
This is required by `TextTimeEmbedding`.
|
||||
|
||||
Args:
|
||||
pooled_text_embedding: The pooled text embedding tensor.
|
||||
"""
|
||||
self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})
|
||||
|
|
Loading…
Reference in a new issue