Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-09 23:12:02 +00:00)
Implement Tiled Autoencoder inference to save VRAM
parent b16cb7fbb6
commit b42881e54e
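For context, the workflow this commit enables looks like this (a minimal sketch based on the docstrings in the diff below; the import path, weights file, and image file are assumptions, not part of this commit):

```python
import torch
from PIL import Image

# Assumed import path for the class this diff modifies.
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder

lda = LatentDiffusionAutoencoder(device=torch.device("cuda"))
lda.load_from_safetensors("lda.safetensors")  # hypothetical weights file

image = Image.open("large_image.png")  # e.g. 2048x2048, too large to encode in one pass

# Tiles are processed one at a time, so peak VRAM scales with the tile size
# rather than with the full image size.
with lda.tiled_inference(image, tile_size=(512, 512), blending=64):
    latents = lda.tiled_image_to_latents(image)
    decoded = lda.tiled_latents_to_image(latents)
```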
@@ -1,6 +1,13 @@
+from contextlib import contextmanager
+from typing import Generator, NamedTuple
+
+import torch
 from PIL import Image
 from torch import Tensor, device as Device, dtype as DType
+from torch.nn import functional as F
+
+from refiners.fluxion import layers as fl
+from refiners.fluxion.adapters.adapter import Adapter
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers import (
     Chain,
@@ -15,7 +22,19 @@ from refiners.fluxion.layers import (
     Sum,
     Upsample,
 )
-from refiners.fluxion.utils import images_to_tensor, tensor_to_images
+from refiners.fluxion.utils import image_to_tensor, images_to_tensor, no_grad, tensor_to_image, tensor_to_images
 
 
+class _ImageSize(NamedTuple):
+    height: int
+    width: int
+
+
+class _Tile(NamedTuple):
+    top: int
+    left: int
+    bottom: int
+    right: int
+
+
 class Resnet(Sum):
@@ -187,6 +206,78 @@ class Decoder(Chain):
         )
 
 
+class FixedGroupNorm(fl.Chain, Adapter[fl.GroupNorm]):
+    """
+    Adapter for GroupNorm layers to fix the running mean and variance.
+
+    This is useful when running tiled inference with an autoencoder, to ensure that the statistics
+    of the GroupNorm layers are consistent across tiles.
+    """
+
+    mean: torch.Tensor | None
+    var: torch.Tensor | None
+
+    def __init__(self, target: fl.GroupNorm) -> None:
+        self.mean = None
+        self.var = None
+        with self.setup_adapter(target):
+            super().__init__(fl.Lambda(self.compute_group_norm))
+
+    def compute_group_norm(self, x: torch.Tensor) -> torch.Tensor:
+        batch, channels, height, width = x.shape
+        # Reshape the tensor so that batch norm applied per group mimics group norm behavior
+        x = x.reshape(
+            1,
+            batch * self.target.num_groups,
+            int(channels / self.target.num_groups),
+            height,
+            width,
+        )
+
+        if self.mean is None or self.var is None:
+            self.var, self.mean = torch.var_mean(x, dim=(0, 2, 3, 4), correction=0)
+
+        result = F.batch_norm(
+            input=x,
+            running_mean=self.mean,
+            running_var=self.var,
+            weight=None,
+            bias=None,
+            training=False,
+            momentum=0,
+            eps=self.target.eps,
+        )
+        result = result.reshape(batch, channels, height, width)
+        return result * self.target.weight.reshape(1, -1, 1, 1) + self.target.bias.reshape(1, -1, 1, 1)
+
+
+def _create_blending_mask(
+    size: _ImageSize,
+    blending: int,
+    num_channels: int,
+    device: torch.device | None = None,
+    is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
+) -> torch.Tensor:
+    mask = torch.ones(size, device=device)
+    if blending == 0:
+        return mask
+    blending = min(blending, min(size) // 2)
+
+    ramp = torch.linspace(0, 1, steps=blending, device=device)
+
+    # Apply ramps only if the tile is not at the corresponding image edge
+    if not is_edge[0]:  # top
+        mask[:blending, :] *= ramp.view(-1, 1)
+    if not is_edge[1]:  # bottom
+        mask[-blending:, :] *= ramp.flip(0).view(-1, 1)
+    if not is_edge[2]:  # left
+        mask[:, :blending] *= ramp.view(1, -1)
+    if not is_edge[3]:  # right
+        mask[:, -blending:] *= ramp.flip(0).view(1, -1)
+
+    return mask.unsqueeze(0).unsqueeze(0).expand(1, num_channels, *size)
+
+
 class LatentDiffusionAutoencoder(Chain):
     """Latent diffusion autoencoder model.
@@ -211,6 +302,8 @@ class LatentDiffusionAutoencoder(Chain):
             Encoder(device=device, dtype=dtype),
             Decoder(device=device, dtype=dtype),
         )
+        self._tile_size = None
+        self._blending = None
 
     def encode(self, x: Tensor) -> Tensor:
         """Encode an image.
@@ -239,8 +332,29 @@ class LatentDiffusionAutoencoder(Chain):
         return x
 
     def image_to_latents(self, image: Image.Image) -> Tensor:
+        """
+        Encode an image to latents.
+        """
         return self.images_to_latents([image])
 
+    def tiled_image_to_latents(self, image: Image.Image) -> Tensor:
+        """
+        Convert an image to latents with gradient blending to smooth tile edges.
+
+        You need to activate the tiled inference context manager with the `tiled_inference` method
+        before calling this method.
+
+        ```python
+        with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
+            latents = lda.tiled_image_to_latents(sample_image)
+        ```
+        """
+        if self._tile_size is None:
+            raise ValueError("Tiled inference context manager not active. Use the `tiled_inference` method to activate it.")
+
+        assert self._tile_size is not None and self._blending is not None
+        image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
+        image_tensor = 2 * image_tensor - 1
+        return self._tiled_encode(image_tensor, self._tile_size, self._blending)
+
     def images_to_latents(self, images: list[Image.Image]) -> Tensor:
         """Convert a list of images to latents.
@@ -260,11 +374,31 @@ class LatentDiffusionAutoencoder(Chain):
         return self.latents_to_image(x)
 
     def latents_to_image(self, x: Tensor) -> Image.Image:
+        """
+        Decode latents to an image.
+        """
         if x.shape[0] != 1:
             raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
 
         return self.latents_to_images(x)[0]
 
+    def tiled_latents_to_image(self, x: Tensor) -> Image.Image:
+        """
+        Convert latents to an image with gradient blending to smooth tile edges.
+
+        You need to activate the tiled inference context manager with the `tiled_inference` method
+        before calling this method.
+
+        ```python
+        with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
+            image = lda.tiled_latents_to_image(latents)
+        ```
+        """
+        if self._tile_size is None:
+            raise ValueError("Tiled inference context manager not active. Use the `tiled_inference` method to activate it.")
+
+        assert self._tile_size is not None and self._blending is not None
+        result = self._tiled_decode(x, self._tile_size, self._blending)
+        return tensor_to_image((result + 1) / 2)
+
     def latents_to_images(self, x: Tensor) -> list[Image.Image]:
         """Convert a tensor of latents to images.
@@ -277,3 +411,209 @@ class LatentDiffusionAutoencoder(Chain):
         x = self.decode(x)
         x = (x + 1) / 2
         return tensor_to_images(x)
+
+    @staticmethod
+    def _generate_latent_tiles(size: _ImageSize, tile_size: _ImageSize, overlap: int = 8) -> list[_Tile]:
+        """
+        Generate tiles for a given latent size and tile size, with a given overlap.
+        """
+        tiles: list[_Tile] = []
+
+        for y in range(0, size.width, tile_size.width - overlap):
+            for x in range(0, size.height, tile_size.height - overlap):
+                tile = _Tile(
+                    top=max(0, x),
+                    left=max(0, y),
+                    bottom=min(size.height, x + tile_size.height),
+                    right=min(size.width, y + tile_size.width),
+                )
+                tiles.append(tile)
+
+        return tiles
+
+    @no_grad()
+    def _add_fixed_group_norm(self, image: Image.Image, inference_size: _ImageSize) -> None:
+        """
+        Set the running mean and variance of the group norm layers in the latent diffusion autoencoder.
+
+        We replace the GroupNorm layers with FixedGroupNorm layers that compute the group norm statistics
+        on their first forward pass and then keep them fixed for all subsequent passes. This is useful when
+        running tiled inference to ensure that the statistics of the GroupNorm layers are consistent across tiles.
+        """
+        for group_norm, parent in self.walk(fl.GroupNorm):
+            FixedGroupNorm(group_norm).inject(parent)
+
+        downscaled_image = image.resize((inference_size.width, inference_size.height))  # type: ignore
+
+        image_tensor = image_to_tensor(image, device=self.device)
+        downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device)
+        downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())
+
+        std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
+        new_std, new_mean = torch.std_mean(downscaled_image_tensor, dim=[0, 2, 3], keepdim=True)
+
+        downscaled_image_tensor = (downscaled_image_tensor - new_mean) * (std / new_std) + mean
+        downscaled_image_tensor = 2 * downscaled_image_tensor - 1
+
+        # A forward pass through the encoder and decoder sets the group norm stats in the FixedGroupNorm layers
+        latents = self.encode(downscaled_image_tensor)
+        self.decode(latents)
+
+    def _remove_fixed_group_norm(self) -> None:
+        """
+        Remove the FixedGroupNorm layers and restore the original GroupNorm layers.
+        """
+        for fixed_group_norm in self.layers(FixedGroupNorm):
+            fixed_group_norm.eject()
+
+    @no_grad()
+    def _tiled_encode(self, image_tensor: torch.Tensor, tile_size: _ImageSize, blending: int = 64) -> torch.Tensor:
+        """
+        Encode an image to latents with tile-based inference and gradient blending to smooth tile edges.
+
+        The tile size and blending amount are the values set by the `tiled_inference` context manager.
+        """
+        latent_size = _ImageSize(height=image_tensor.shape[2] // 8, width=image_tensor.shape[3] // 8)
+        target_latent_tile_size = _ImageSize(height=tile_size.height // 8, width=tile_size.width // 8)
+        tiles = self._generate_latent_tiles(latent_size, tile_size=target_latent_tile_size, overlap=blending // 8)
+
+        if len(tiles) == 1:
+            return self.encode(image_tensor)
+
+        result = torch.zeros((1, 4, *latent_size), device=self.device)
+        weights = torch.zeros_like(result)
+
+        for latent_tile in tiles:
+            pixel_tile = image_tensor[
+                :,
+                :,
+                latent_tile.top * 8 : latent_tile.bottom * 8,
+                latent_tile.left * 8 : latent_tile.right * 8,
+            ]
+            encoded_tile = self.encode(pixel_tile)
+
+            is_edge = (
+                latent_tile.top == 0,
+                latent_tile.bottom == latent_size.height,
+                latent_tile.left == 0,
+                latent_tile.right == latent_size.width,
+            )
+
+            latent_tile_size = _ImageSize(
+                height=(latent_tile.bottom - latent_tile.top), width=(latent_tile.right - latent_tile.left)
+            )
+
+            tile_mask = _create_blending_mask(
+                latent_tile_size,
+                blending // 8,
+                num_channels=4,
+                device=self.device,
+                is_edge=is_edge,
+            )
+
+            result[
+                :,
+                :,
+                latent_tile.top : latent_tile.bottom,
+                latent_tile.left : latent_tile.right,
+            ] += encoded_tile * tile_mask
+
+            weights[
+                :,
+                :,
+                latent_tile.top : latent_tile.bottom,
+                latent_tile.left : latent_tile.right,
+            ] += tile_mask
+
+        return result / weights
+
+    @no_grad()
+    def _tiled_decode(self, latents: torch.Tensor, tile_size: _ImageSize, blending: int = 64) -> torch.Tensor:
+        """
+        Decode latents to an image with tile-based inference and gradient blending to smooth tile edges.
+
+        The tile size and blending amount are the values set by the `tiled_inference` context manager.
+        """
+        latent_size = _ImageSize(height=latents.shape[2], width=latents.shape[3])
+        pixel_size = _ImageSize(height=latent_size.height * 8, width=latent_size.width * 8)
+        target_latent_tile_size = _ImageSize(height=tile_size.height // 8, width=tile_size.width // 8)
+        tiles = self._generate_latent_tiles(latent_size, tile_size=target_latent_tile_size, overlap=blending // 8)
+        if len(tiles) == 1:
+            return self.decode(latents)
+
+        result = torch.zeros((1, 3, *pixel_size), device=self.device)
+        weights = torch.zeros_like(result)
+
+        for latent_tile in tiles:
+            pixel_offset = _ImageSize(height=latent_tile.top * 8, width=latent_tile.left * 8)
+            latent_tile_size = _ImageSize(
+                height=latent_tile.bottom - latent_tile.top, width=latent_tile.right - latent_tile.left
+            )
+            pixel_tile_size = _ImageSize(height=latent_tile_size.height * 8, width=latent_tile_size.width * 8)
+
+            pixel_tile = self.decode(
+                latents[
+                    :,
+                    :,
+                    latent_tile.top : latent_tile.bottom,
+                    latent_tile.left : latent_tile.right,
+                ]
+            )
+
+            is_edge = (
+                latent_tile.top == 0,
+                latent_tile.bottom == latent_size.height,
+                latent_tile.left == 0,
+                latent_tile.right == latent_size.width,
+            )
+
+            pixel_tile_mask = _create_blending_mask(
+                pixel_tile_size, blending, num_channels=3, device=self.device, is_edge=is_edge
+            )
+            result[
+                :,
+                :,
+                pixel_offset.height : pixel_offset.height + pixel_tile_size.height,
+                pixel_offset.width : pixel_offset.width + pixel_tile_size.width,
+            ] += pixel_tile * pixel_tile_mask
+
+            weights[
+                :,
+                :,
+                pixel_offset.height : pixel_offset.height + pixel_tile_size.height,
+                pixel_offset.width : pixel_offset.width + pixel_tile_size.width,
+            ] += pixel_tile_mask
+
+        return result / weights
+
+    @contextmanager
+    def tiled_inference(
+        self, image: Image.Image, tile_size: tuple[int, int] = (512, 512), blending: int = 64
+    ) -> Generator[None, None, None]:
+        """
+        Context manager for tiled inference operations to save VRAM for large images.
+
+        This context manager sets up consistent GroupNorm statistics for performing tiled operations on the
+        autoencoder, setting them on entry and resetting them on exit. It ensures the result is consistent
+        across tiles by capturing the statistics of the GroupNorm layers on a downscaled version of the image.
+
+        Be careful not to use the regular `image_to_latents` and `latents_to_image` methods while this context
+        manager is active: they will silently run the operation without tiling.
+
+        ```python
+        with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):
+            latents = lda.tiled_image_to_latents(sample_image)
+            decoded_image = lda.tiled_latents_to_image(latents)
+        ```
+        """
+        try:
+            self._blending = blending
+            self._tile_size = _ImageSize(width=tile_size[0], height=tile_size[1])
+            self._add_fixed_group_norm(image, inference_size=self._tile_size)
+            yield
+        finally:
+            self._remove_fixed_group_norm()
+            self._tile_size = None
+            self._blending = None
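The weighted overlap-add performed by `_tiled_encode` and `_tiled_decode` above is easiest to see in one dimension: each tile is multiplied by a mask that ramps down on interior edges, masked tiles accumulate into `result`, the masks accumulate into `weights`, and the final division normalizes the blend. A minimal 1-D sketch of that scheme (toy sizes, plain torch):

```python
import torch

blending = 4
ramp = torch.linspace(0, 1, steps=blending)

# Two tiles of length 8 covering a signal of length 12, overlapping by 4.
left_mask = torch.ones(8)
left_mask[-blending:] *= ramp.flip(0)  # fade out on the interior (right) edge
right_mask = torch.ones(8)
right_mask[:blending] *= ramp          # fade in on the interior (left) edge

signal = torch.arange(12.0)
result = torch.zeros(12)
weights = torch.zeros(12)

result[:8] += signal[:8] * left_mask   # accumulate masked tile contents...
weights[:8] += left_mask               # ...and the masks themselves
result[4:] += signal[4:] * right_mask
weights[4:] += right_mask

# Dividing by the accumulated weights recovers the signal exactly, so
# identical overlapping content merges without a visible seam.
print(torch.allclose(result / weights, signal))  # True
```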
The remaining hunks update the autoencoder tests: the fixture is renamed from `encoder` to `lda`, and tests for tiled inference are added.

@@ -16,7 +16,7 @@ def ref_path() -> Path:
 
 
 @pytest.fixture(scope="module")
-def encoder(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
+def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
     lda_weights = test_weights_path / "lda.safetensors"
     if not lda_weights.is_file():
         warn(f"could not find weights at {lda_weights}, skipping")
@@ -39,9 +39,9 @@ def sample_image(ref_path: Path) -> Image.Image:
 
 
 @no_grad()
-def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    encoded = encoder.image_to_latents(sample_image)
-    decoded = encoder.latents_to_image(encoded)
+def test_encode_decode_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = lda.image_to_latents(sample_image)
+    decoded = lda.latents_to_image(encoded)
 
     assert decoded.mode == "RGB"  # type: ignore
@@ -52,9 +52,62 @@ def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image:
 
 
 @no_grad()
-def test_encode_decode_images(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    encoded = encoder.images_to_latents([sample_image, sample_image])
-    images = encoder.latents_to_images(encoded)
+def test_encode_decode_images(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = lda.images_to_latents([sample_image, sample_image])
+    images = lda.latents_to_images(encoded)
     assert isinstance(images, list)
     assert len(images) == 2
     ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
+
+
+@no_grad()
+def test_tiled_autoencoder(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    sample_image = sample_image.resize((2048, 2048))  # type: ignore
+
+    with lda.tiled_inference(sample_image, tile_size=(512, 512)):
+        encoded = lda.tiled_image_to_latents(sample_image)
+        result = lda.tiled_latents_to_image(encoded)
+
+    ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
+
+
+@no_grad()
+def test_tiled_autoencoder_rectangular_tiles(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    sample_image = sample_image.resize((2048, 2048))  # type: ignore
+
+    with lda.tiled_inference(sample_image, tile_size=(512, 1024)):
+        encoded = lda.tiled_image_to_latents(sample_image)
+        result = lda.tiled_latents_to_image(encoded)
+
+    ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
+
+
+@no_grad()
+def test_tiled_autoencoder_large_tile(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    sample_image = sample_image.resize((1024, 1024))  # type: ignore
+
+    with lda.tiled_inference(sample_image, tile_size=(2048, 2048)):
+        encoded = lda.tiled_image_to_latents(sample_image)
+        result = lda.tiled_latents_to_image(encoded)
+
+    ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975)
+
+
+@no_grad()
+def test_tiled_autoencoder_rectangular_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    sample_image = sample_image.crop((0, 0, 300, 500))
+    sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4))  # type: ignore
+
+    with lda.tiled_inference(sample_image, tile_size=(512, 512)):
+        encoded = lda.tiled_image_to_latents(sample_image)
+        result = lda.tiled_latents_to_image(encoded)
+
+    ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985)
+
+
+def test_value_error_tile_encode_no_context(lda: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
+    with pytest.raises(ValueError):
+        lda.tiled_image_to_latents(sample_image)
+
+    with pytest.raises(ValueError):
+        lda.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=lda.device))
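Alongside these quality tests, one invariant of the tiling scheme can be checked in isolation: stepping by `tile_size - overlap` and clamping each tile to the image bounds leaves no uncovered positions. A standalone sketch (the helper below is a hypothetical reimplementation of the `_generate_latent_tiles` logic, not an import from refiners):

```python
import torch

# Hypothetical reimplementation of the _generate_latent_tiles logic:
# step by (tile size - overlap) and clamp each tile to the image bounds.
def generate_tiles(height: int, width: int, tile_h: int, tile_w: int, overlap: int = 8):
    tiles: list[tuple[int, int, int, int]] = []
    for top in range(0, height, tile_h - overlap):
        for left in range(0, width, tile_w - overlap):
            tiles.append((top, left, min(height, top + tile_h), min(width, left + tile_w)))
    return tiles

# Count how many tiles cover each position of a 96x160 latent grid.
coverage = torch.zeros(96, 160, dtype=torch.int32)
for top, left, bottom, right in generate_tiles(96, 160, 64, 64):
    coverage[top:bottom, left:right] += 1

assert int(coverage.min()) >= 1  # no holes: every position belongs to a tile
print(int(coverage.max()))       # overlap regions are covered more than once
```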