mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 07:38:45 +00:00
(doc/fluxion/ld) add LatentDiffusionAutoencoder
docstrings
This commit is contained in:
parent
effd95a1bd
commit
2a7b86ac02
|
@ -188,6 +188,12 @@ class Decoder(Chain):
|
||||||
|
|
||||||
|
|
||||||
class LatentDiffusionAutoencoder(Chain):
|
class LatentDiffusionAutoencoder(Chain):
|
||||||
|
"""Latent diffusion autoencoder model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
encoder_scale: The encoder scale to use.
|
||||||
|
"""
|
||||||
|
|
||||||
encoder_scale = 0.18125
|
encoder_scale = 0.18125
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -195,33 +201,67 @@ class LatentDiffusionAutoencoder(Chain):
|
||||||
device: Device | str | None = None,
|
device: Device | str | None = None,
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initializes the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device: The PyTorch device to use.
|
||||||
|
dtype: The PyTorch data type to use.
|
||||||
|
"""
|
||||||
super().__init__(
|
super().__init__(
|
||||||
Encoder(device=device, dtype=dtype),
|
Encoder(device=device, dtype=dtype),
|
||||||
Decoder(device=device, dtype=dtype),
|
Decoder(device=device, dtype=dtype),
|
||||||
)
|
)
|
||||||
|
|
||||||
def encode(self, x: Tensor) -> Tensor:
|
def encode(self, x: Tensor) -> Tensor:
|
||||||
|
"""Encode an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The image tensor to encode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The encoded tensor.
|
||||||
|
"""
|
||||||
encoder = self[0]
|
encoder = self[0]
|
||||||
x = self.encoder_scale * encoder(x)
|
x = self.encoder_scale * encoder(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def decode(self, x: Tensor) -> Tensor:
|
def decode(self, x: Tensor) -> Tensor:
|
||||||
|
"""Decode a latent tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The latent to decode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The decoded image tensor.
|
||||||
|
"""
|
||||||
decoder = self[1]
|
decoder = self[1]
|
||||||
x = decoder(x / self.encoder_scale)
|
x = decoder(x / self.encoder_scale)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
# backward-compatibility alias
|
||||||
|
# TODO: deprecate this method
|
||||||
def image_to_latents(self, image: Image.Image) -> Tensor:
|
def image_to_latents(self, image: Image.Image) -> Tensor:
|
||||||
return self.images_to_latents([image])
|
return self.images_to_latents([image])
|
||||||
|
|
||||||
def images_to_latents(self, images: list[Image.Image]) -> Tensor:
|
def images_to_latents(self, images: list[Image.Image]) -> Tensor:
|
||||||
|
"""Convert a list of images to latents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
images: The list of images to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor containing the latents associated with the images.
|
||||||
|
"""
|
||||||
x = images_to_tensor(images, device=self.device, dtype=self.dtype)
|
x = images_to_tensor(images, device=self.device, dtype=self.dtype)
|
||||||
x = 2 * x - 1
|
x = 2 * x - 1
|
||||||
return self.encode(x)
|
return self.encode(x)
|
||||||
|
|
||||||
# backward-compatibility alias
|
# backward-compatibility alias
|
||||||
|
# TODO: deprecate this method
|
||||||
def decode_latents(self, x: Tensor) -> Image.Image:
|
def decode_latents(self, x: Tensor) -> Image.Image:
|
||||||
return self.latents_to_image(x)
|
return self.latents_to_image(x)
|
||||||
|
|
||||||
|
# TODO: deprecated this method ?
|
||||||
def latents_to_image(self, x: Tensor) -> Image.Image:
|
def latents_to_image(self, x: Tensor) -> Image.Image:
|
||||||
if x.shape[0] != 1:
|
if x.shape[0] != 1:
|
||||||
raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
|
raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
|
||||||
|
@ -229,6 +269,14 @@ class LatentDiffusionAutoencoder(Chain):
|
||||||
return self.latents_to_images(x)[0]
|
return self.latents_to_images(x)[0]
|
||||||
|
|
||||||
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
|
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
|
||||||
|
"""Convert a tensor of latents to images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The tensor of latents to convert.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of images associated with the latents.
|
||||||
|
"""
|
||||||
x = self.decode(x)
|
x = self.decode(x)
|
||||||
x = (x + 1) / 2
|
x = (x + 1) / 2
|
||||||
return tensor_to_images(x)
|
return tensor_to_images(x)
|
||||||
|
|
Loading…
Reference in a new issue