(doc/fluxion/ld) add LatentDiffusionAutoencoder docstrings

This commit is contained in:
Laurent 2024-02-02 11:11:50 +00:00 committed by Laureηt
parent 08c453345a
commit 511221e73d

View file

@ -188,6 +188,12 @@ class Decoder(Chain):
class LatentDiffusionAutoencoder(Chain):
"""Latent diffusion autoencoder model.
Attributes:
encoder_scale: The encoder scale to use.
"""
encoder_scale = 0.18125
def __init__(
@ -195,33 +201,67 @@ class LatentDiffusionAutoencoder(Chain):
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
"""Initializes the model.
Args:
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
super().__init__(
Encoder(device=device, dtype=dtype),
Decoder(device=device, dtype=dtype),
)
def encode(self, x: Tensor) -> Tensor:
"""Encode an image.
Args:
x: The image tensor to encode.
Returns:
The encoded tensor.
"""
encoder = self[0]
x = self.encoder_scale * encoder(x)
return x
def decode(self, x: Tensor) -> Tensor:
"""Decode a latent tensor.
Args:
x: The latent to decode.
Returns:
The decoded image tensor.
"""
decoder = self[1]
x = decoder(x / self.encoder_scale)
return x
# backward-compatibility alias
# TODO: deprecate this method
def image_to_latents(self, image: Image.Image) -> Tensor:
return self.images_to_latents([image])
def images_to_latents(self, images: list[Image.Image]) -> Tensor:
"""Convert a list of images to latents.
Args:
images: The list of images to convert.
Returns:
A tensor containing the latents associated with the images.
"""
x = images_to_tensor(images, device=self.device, dtype=self.dtype)
x = 2 * x - 1
return self.encode(x)
# backward-compatibility alias
# TODO: deprecate this method
def decode_latents(self, x: Tensor) -> Image.Image:
return self.latents_to_image(x)
# TODO: deprecated this method ?
def latents_to_image(self, x: Tensor) -> Image.Image:
if x.shape[0] != 1:
raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
@ -229,6 +269,14 @@ class LatentDiffusionAutoencoder(Chain):
return self.latents_to_images(x)[0]
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
"""Convert a tensor of latents to images.
Args:
x: The tensor of latents to convert.
Returns:
A list of images associated with the latents.
"""
x = self.decode(x)
x = (x + 1) / 2
return tensor_to_images(x)