From 2a7b86ac02a7d7a48cab7d16ac2154be19047b9c Mon Sep 17 00:00:00 2001 From: Laurent Date: Fri, 2 Feb 2024 11:11:50 +0000 Subject: [PATCH] (doc/fluxion/ld) add `LatentDiffusionAutoencoder` docstrings --- .../latent_diffusion/auto_encoder.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py index 70294f5..8b9e999 100644 --- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py +++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py @@ -188,6 +188,12 @@ class Decoder(Chain): class LatentDiffusionAutoencoder(Chain): + """Latent diffusion autoencoder model. + + Attributes: + encoder_scale: The encoder scale to use. + """ + encoder_scale = 0.18125 def __init__( @@ -195,33 +201,67 @@ class LatentDiffusionAutoencoder(Chain): device: Device | str | None = None, dtype: DType | None = None, ) -> None: + """Initializes the model. + + Args: + device: The PyTorch device to use. + dtype: The PyTorch data type to use. + """ super().__init__( Encoder(device=device, dtype=dtype), Decoder(device=device, dtype=dtype), ) def encode(self, x: Tensor) -> Tensor: + """Encode an image. + + Args: + x: The image tensor to encode. + + Returns: + The encoded tensor. + """ encoder = self[0] x = self.encoder_scale * encoder(x) return x def decode(self, x: Tensor) -> Tensor: + """Decode a latent tensor. + + Args: + x: The latent to decode. + + Returns: + The decoded image tensor. + """ decoder = self[1] x = decoder(x / self.encoder_scale) return x + # backward-compatibility alias + # TODO: deprecate this method def image_to_latents(self, image: Image.Image) -> Tensor: return self.images_to_latents([image]) def images_to_latents(self, images: list[Image.Image]) -> Tensor: + """Convert a list of images to latents. + + Args: + images: The list of images to convert. + + Returns: + A tensor containing the latents associated with the images. 
+ """ x = images_to_tensor(images, device=self.device, dtype=self.dtype) x = 2 * x - 1 return self.encode(x) # backward-compatibility alias + # TODO: deprecate this method def decode_latents(self, x: Tensor) -> Image.Image: return self.latents_to_image(x) + # TODO: deprecate this method? def latents_to_image(self, x: Tensor) -> Image.Image: if x.shape[0] != 1: raise ValueError(f"Expected batch size of 1, got {x.shape[0]}") @@ -229,6 +269,14 @@ class LatentDiffusionAutoencoder(Chain): return self.latents_to_images(x)[0] def latents_to_images(self, x: Tensor) -> list[Image.Image]: + """Convert a tensor of latents to images. + + Args: + x: The tensor of latents to convert. + + Returns: + A list of images associated with the latents. + """ x = self.decode(x) x = (x + 1) / 2 return tensor_to_images(x)