mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
(doc/fluxion/ld) add SDXLUNet docstrings
This commit is contained in:
parent
270357ed29
commit
fae08c058e
|
@ -239,7 +239,24 @@ class OutputBlock(fl.Chain):
|
||||||
|
|
||||||
|
|
||||||
class SDXLUNet(fl.Chain):
|
class SDXLUNet(fl.Chain):
|
||||||
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
"""Stable Diffusion XL U-Net.
|
||||||
|
|
||||||
|
See [[arXiv:2307.01952] SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://arxiv.org/abs/2307.01952) for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the U-Net.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels.
|
||||||
|
device: Device to use for computation.
|
||||||
|
dtype: Data type to use for computation.
|
||||||
|
"""
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
super().__init__(
|
super().__init__(
|
||||||
TimestepEncoder(device=device, dtype=dtype),
|
TimestepEncoder(device=device, dtype=dtype),
|
||||||
|
@ -273,13 +290,45 @@ class SDXLUNet(fl.Chain):
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
||||||
|
"""Set the clip text embedding context.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This context is required by the `SDXLCrossAttention` blocks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clip_text_embedding: The CLIP text embedding tensor.
|
||||||
|
"""
|
||||||
self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})
|
self.set_context(context="cross_attention_block", value={"clip_text_embedding": clip_text_embedding})
|
||||||
|
|
||||||
def set_timestep(self, timestep: Tensor) -> None:
|
def set_timestep(self, timestep: Tensor) -> None:
|
||||||
|
"""Set the timestep context.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This is required by `TimestepEncoder`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep: The timestep tensor.
|
||||||
|
"""
|
||||||
self.set_context(context="diffusion", value={"timestep": timestep})
|
self.set_context(context="diffusion", value={"timestep": timestep})
|
||||||
|
|
||||||
def set_time_ids(self, time_ids: Tensor) -> None:
|
def set_time_ids(self, time_ids: Tensor) -> None:
|
||||||
|
"""Set the time IDs context.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This is required by `TextTimeEmbedding`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_ids: The time IDs tensor.
|
||||||
|
"""
|
||||||
self.set_context(context="diffusion", value={"time_ids": time_ids})
|
self.set_context(context="diffusion", value={"time_ids": time_ids})
|
||||||
|
|
||||||
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
|
def set_pooled_text_embedding(self, pooled_text_embedding: Tensor) -> None:
|
||||||
|
"""Set the pooled text embedding context.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This is required by `TextTimeEmbedding`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pooled_text_embedding: The pooled text embedding tensor.
|
||||||
|
"""
|
||||||
self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})
|
self.set_context(context="diffusion", value={"pooled_text_embedding": pooled_text_embedding})
|
||||||
|
|
Loading…
Reference in a new issue