mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
(doc/fluxion/ld) add SD1UNet
docstrings
This commit is contained in:
parent
f93a145e15
commit
828ca0b178
|
@ -234,7 +234,23 @@ class ResidualConcatenator(fl.Chain):
|
|||
|
||||
|
||||
class SD1UNet(fl.Chain):
|
||||
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
||||
"""Stable Diffusion 1.5 U-Net.
|
||||
|
||||
See [[arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) for more details."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Initialize the U-Net.
|
||||
|
||||
Args:
|
||||
in_channels: The number of input channels.
|
||||
device: The PyTorch device to use for computation.
|
||||
dtype: The PyTorch dtype to use for computation.
|
||||
"""
|
||||
self.in_channels = in_channels
|
||||
super().__init__(
|
||||
TimestepEncoder(device=device, dtype=dtype),
|
||||
|
@ -282,7 +298,23 @@ class SD1UNet(fl.Chain):
|
|||
}
|
||||
|
||||
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
||||
"""Set the CLIP text embedding.
|
||||
|
||||
Note:
|
||||
This context is required by the `CLIPLCrossAttention` blocks.
|
||||
|
||||
Args:
|
||||
clip_text_embedding: The CLIP text embedding.
|
||||
"""
|
||||
self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})
|
||||
|
||||
def set_timestep(self, timestep: Tensor) -> None:
|
||||
"""Set the timestep.
|
||||
|
||||
Note:
|
||||
This context is required by `TimestepEncoder`.
|
||||
|
||||
Args:
|
||||
timestep: The timestep.
|
||||
"""
|
||||
self.set_context("diffusion", {"timestep": timestep})
|
||||
|
|
Loading…
Reference in a new issue