mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
(doc/fluxion/ld) add LatentDiffusionAutoencoder
docstrings
This commit is contained in:
parent
08c453345a
commit
511221e73d
|
@ -188,6 +188,12 @@ class Decoder(Chain):
|
|||
|
||||
|
||||
class LatentDiffusionAutoencoder(Chain):
|
||||
"""Latent diffusion autoencoder model.
|
||||
|
||||
Attributes:
|
||||
encoder_scale: The encoder scale to use.
|
||||
"""
|
||||
|
||||
encoder_scale = 0.18125
|
||||
|
||||
def __init__(
|
||||
|
@ -195,33 +201,67 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
device: Device | str | None = None,
|
||||
dtype: DType | None = None,
|
||||
) -> None:
|
||||
"""Initializes the model.
|
||||
|
||||
Args:
|
||||
device: The PyTorch device to use.
|
||||
dtype: The PyTorch data type to use.
|
||||
"""
|
||||
super().__init__(
|
||||
Encoder(device=device, dtype=dtype),
|
||||
Decoder(device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def encode(self, x: Tensor) -> Tensor:
|
||||
"""Encode an image.
|
||||
|
||||
Args:
|
||||
x: The image tensor to encode.
|
||||
|
||||
Returns:
|
||||
The encoded tensor.
|
||||
"""
|
||||
encoder = self[0]
|
||||
x = self.encoder_scale * encoder(x)
|
||||
return x
|
||||
|
||||
def decode(self, x: Tensor) -> Tensor:
|
||||
"""Decode a latent tensor.
|
||||
|
||||
Args:
|
||||
x: The latent to decode.
|
||||
|
||||
Returns:
|
||||
The decoded image tensor.
|
||||
"""
|
||||
decoder = self[1]
|
||||
x = decoder(x / self.encoder_scale)
|
||||
return x
|
||||
|
||||
# backward-compatibility alias
|
||||
# TODO: deprecate this method
|
||||
def image_to_latents(self, image: Image.Image) -> Tensor:
|
||||
return self.images_to_latents([image])
|
||||
|
||||
def images_to_latents(self, images: list[Image.Image]) -> Tensor:
|
||||
"""Convert a list of images to latents.
|
||||
|
||||
Args:
|
||||
images: The list of images to convert.
|
||||
|
||||
Returns:
|
||||
A tensor containing the latents associated with the images.
|
||||
"""
|
||||
x = images_to_tensor(images, device=self.device, dtype=self.dtype)
|
||||
x = 2 * x - 1
|
||||
return self.encode(x)
|
||||
|
||||
# backward-compatibility alias
|
||||
# TODO: deprecate this method
|
||||
def decode_latents(self, x: Tensor) -> Image.Image:
|
||||
return self.latents_to_image(x)
|
||||
|
||||
# TODO: deprecated this method ?
|
||||
def latents_to_image(self, x: Tensor) -> Image.Image:
|
||||
if x.shape[0] != 1:
|
||||
raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
|
||||
|
@ -229,6 +269,14 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
return self.latents_to_images(x)[0]
|
||||
|
||||
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
|
||||
"""Convert a tensor of latents to images.
|
||||
|
||||
Args:
|
||||
x: The tensor of latents to convert.
|
||||
|
||||
Returns:
|
||||
A list of images associated with the latents.
|
||||
"""
|
||||
x = self.decode(x)
|
||||
x = (x + 1) / 2
|
||||
return tensor_to_images(x)
|
||||
|
|
Loading…
Reference in a new issue