mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
(doc/fluxion/ld) add SD1UNet
docstrings
This commit is contained in:
parent
fae08c058e
commit
289261f2fb
|
@ -234,7 +234,23 @@ class ResidualConcatenator(fl.Chain):
|
||||||
|
|
||||||
|
|
||||||
class SD1UNet(fl.Chain):
|
class SD1UNet(fl.Chain):
|
||||||
def __init__(self, in_channels: int, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
"""Stable Diffusion 1.5 U-Net.
|
||||||
|
|
||||||
|
See [[arXiv:2112.10752] High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) for more details."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the U-Net.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels: The number of input channels.
|
||||||
|
device: The PyTorch device to use for computation.
|
||||||
|
dtype: The PyTorch dtype to use for computation.
|
||||||
|
"""
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
super().__init__(
|
super().__init__(
|
||||||
TimestepEncoder(device=device, dtype=dtype),
|
TimestepEncoder(device=device, dtype=dtype),
|
||||||
|
@ -282,7 +298,23 @@ class SD1UNet(fl.Chain):
|
||||||
}
|
}
|
||||||
|
|
||||||
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
def set_clip_text_embedding(self, clip_text_embedding: Tensor) -> None:
|
||||||
|
"""Set the CLIP text embedding.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This context is required by the `CLIPLCrossAttention` blocks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
clip_text_embedding: The CLIP text embedding.
|
||||||
|
"""
|
||||||
self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})
|
self.set_context("cross_attention_block", {"clip_text_embedding": clip_text_embedding})
|
||||||
|
|
||||||
def set_timestep(self, timestep: Tensor) -> None:
|
def set_timestep(self, timestep: Tensor) -> None:
|
||||||
|
"""Set the timestep.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This context is required by `TimestepEncoder`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep: The timestep.
|
||||||
|
"""
|
||||||
self.set_context("diffusion", {"timestep": timestep})
|
self.set_context("diffusion", {"timestep": timestep})
|
||||||
|
|
Loading…
Reference in a new issue