Implement Tiled Autoencoder inference to save VRAM

This commit is contained in:
limiteinductive 2024-06-24 16:35:43 +00:00 committed by Benjamin Trom
parent b16cb7fbb6
commit b42881e54e
2 changed files with 401 additions and 8 deletions

View file

@ -1,6 +1,13 @@
from contextlib import contextmanager
from typing import Generator, NamedTuple
import torch
from PIL import Image from PIL import Image
from torch import Tensor, device as Device, dtype as DType from torch import Tensor, device as Device, dtype as DType
from torch.nn import functional as F
from refiners.fluxion import layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts from refiners.fluxion.context import Contexts
from refiners.fluxion.layers import ( from refiners.fluxion.layers import (
Chain, Chain,
@ -15,7 +22,19 @@ from refiners.fluxion.layers import (
Sum, Sum,
Upsample, Upsample,
) )
from refiners.fluxion.utils import images_to_tensor, tensor_to_images from refiners.fluxion.utils import image_to_tensor, images_to_tensor, no_grad, tensor_to_image, tensor_to_images
class _ImageSize(NamedTuple):
    """Two-dimensional size, stored as (height, width) in that order."""

    # Number of rows.
    height: int
    # Number of columns.
    width: int
class _Tile(NamedTuple):
    """Rectangular region of a 2D grid, used as `top:bottom` / `left:right` slice bounds."""

    # First row of the tile.
    top: int
    # First column of the tile.
    left: int
    # Row bound used as the (exclusive) end of the row slice.
    bottom: int
    # Column bound used as the (exclusive) end of the column slice.
    right: int
class Resnet(Sum): class Resnet(Sum):
@ -187,6 +206,78 @@ class Decoder(Chain):
) )
class FixedGroupNorm(fl.Chain, Adapter[fl.GroupNorm]):
    """
    GroupNorm adapter that freezes the normalization statistics.

    The per-group mean and variance are measured on the first forward pass and reused on every later pass. This is
    useful when running tiled inference with an autoencoder, so that every tile is normalized with the same statistics.
    """

    mean: torch.Tensor | None
    var: torch.Tensor | None

    def __init__(self, target: fl.GroupNorm) -> None:
        # No statistics are known until the first forward pass records them.
        self.mean = None
        self.var = None
        with self.setup_adapter(target):
            super().__init__(fl.Lambda(self.compute_group_norm))

    def compute_group_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x` group-wise with the stored statistics, recording them first if they are not set yet."""
        group_norm = self.target
        num_groups = group_norm.num_groups
        batch, channels, height, width = x.shape

        # Expose every (sample, group) pair as one batch-norm channel, so that batch norm mimics group norm.
        grouped = x.reshape(1, batch * num_groups, channels // num_groups, height, width)

        if self.mean is None or self.var is None:
            self.var, self.mean = torch.var_mean(grouped, dim=(0, 2, 3, 4), correction=0)

        normalized = F.batch_norm(
            input=grouped,
            running_mean=self.mean,
            running_var=self.var,
            weight=None,
            bias=None,
            training=False,
            momentum=0,
            eps=group_norm.eps,
        ).reshape(batch, channels, height, width)

        # The affine parameters of the wrapped GroupNorm are applied per channel.
        scale = group_norm.weight.reshape(1, -1, 1, 1)
        shift = group_norm.bias.reshape(1, -1, 1, 1)
        return normalized * scale + shift
def _create_blending_mask(
size: _ImageSize,
blending: int,
num_channels: int,
device: torch.device | None = None,
is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
) -> torch.Tensor:
mask = torch.ones(size, device=device)
if blending == 0:
return mask
blending = min(blending, min(size) // 2)
ramp = torch.linspace(0, 1, steps=blending, device=device)
# Apply ramps only if not at the corresponding edge
if not is_edge[0]: # top
mask[:blending, :] *= ramp.view(-1, 1)
if not is_edge[1]: # bottom
mask[-blending:, :] *= ramp.flip(0).view(-1, 1)
if not is_edge[2]: # left
mask[:, :blending] *= ramp.view(1, -1)
if not is_edge[3]: # right
mask[:, -blending:] *= ramp.flip(0).view(1, -1)
return mask.unsqueeze(0).unsqueeze(0).expand(1, num_channels, *size)
class LatentDiffusionAutoencoder(Chain): class LatentDiffusionAutoencoder(Chain):
"""Latent diffusion autoencoder model. """Latent diffusion autoencoder model.
@ -211,6 +302,8 @@ class LatentDiffusionAutoencoder(Chain):
Encoder(device=device, dtype=dtype), Encoder(device=device, dtype=dtype),
Decoder(device=device, dtype=dtype), Decoder(device=device, dtype=dtype),
) )
self._tile_size = None
self._blending = None
def encode(self, x: Tensor) -> Tensor: def encode(self, x: Tensor) -> Tensor:
"""Encode an image. """Encode an image.
@ -239,8 +332,29 @@ class LatentDiffusionAutoencoder(Chain):
return x return x
def image_to_latents(self, image: Image.Image) -> Tensor: def image_to_latents(self, image: Image.Image) -> Tensor:
"""
Encode an image to latents.
"""
return self.images_to_latents([image]) return self.images_to_latents([image])
def tiled_image_to_latents(self, image: Image.Image) -> Tensor:
    """
    Encode an image to latents tile by tile, blending the tiles with a gradient to smooth their edges.

    This method only works inside the `tiled_inference` context manager:

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
        latents = lda.tiled_image_to_latents(sample_image)
    ```

    Raises:
        ValueError: If the `tiled_inference` context manager is not active.
    """
    tile_size, blending = self._tile_size, self._blending
    if tile_size is None:
        raise ValueError("Tiled inference context manager not active. Use `tiled_inference` method to activate.")
    assert blending is not None

    pixels = image_to_tensor(image, device=self.device, dtype=self.dtype)
    # Rescale the pixel values with the same 2 * x - 1 mapping applied before encoding elsewhere in this class.
    return self._tiled_encode(2 * pixels - 1, tile_size, blending)
def images_to_latents(self, images: list[Image.Image]) -> Tensor: def images_to_latents(self, images: list[Image.Image]) -> Tensor:
"""Convert a list of images to latents. """Convert a list of images to latents.
@ -260,11 +374,31 @@ class LatentDiffusionAutoencoder(Chain):
return self.latents_to_image(x) return self.latents_to_image(x)
def latents_to_image(self, x: Tensor) -> Image.Image: def latents_to_image(self, x: Tensor) -> Image.Image:
"""
Decode latents to an image.
"""
if x.shape[0] != 1: if x.shape[0] != 1:
raise ValueError(f"Expected batch size of 1, got {x.shape[0]}") raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
return self.latents_to_images(x)[0] return self.latents_to_images(x)[0]
def tiled_latents_to_image(self, x: Tensor) -> Image.Image:
    """
    Decode latents to an image tile by tile, blending the tiles with a gradient to smooth their edges.

    This method only works inside the `tiled_inference` context manager:

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
        image = lda.tiled_latents_to_image(latents)
    ```

    Raises:
        ValueError: If the `tiled_inference` context manager is not active.
    """
    tile_size, blending = self._tile_size, self._blending
    if tile_size is None:
        raise ValueError("Tiled inference context manager not active. Use `tiled_inference` method to activate.")
    assert blending is not None

    decoded = self._tiled_decode(x, tile_size, blending)
    # Undo the 2 * x - 1 rescaling before converting the tensor to an image.
    return tensor_to_image((decoded + 1) / 2)
def latents_to_images(self, x: Tensor) -> list[Image.Image]: def latents_to_images(self, x: Tensor) -> list[Image.Image]:
"""Convert a tensor of latents to images. """Convert a tensor of latents to images.
@ -277,3 +411,209 @@ class LatentDiffusionAutoencoder(Chain):
x = self.decode(x) x = self.decode(x)
x = (x + 1) / 2 x = (x + 1) / 2
return tensor_to_images(x) return tensor_to_images(x)
@staticmethod
def _generate_latent_tiles(size: _ImageSize, tile_size: _ImageSize, overlap: int = 8) -> list[_Tile]:
    """
    Generate the tiles covering a latent of the given size.

    Tiles are laid out column by column, each one starting `tile side - overlap` after the previous one and being
    clipped to the latent bounds. Along each axis the layout stops as soon as a tile reaches the border, so no tile
    is ever fully contained in the previous one.

    Args:
        size: Size of the latent to cover.
        tile_size: Maximum size of a tile.
        overlap: Number of rows/columns shared by two neighbouring tiles.

    Returns:
        The list of tiles; it holds a single tile when `tile_size` covers the whole latent.

    Raises:
        ValueError: If `overlap` is not smaller than both sides of `tile_size`.
    """

    def tile_starts(length: int, tile_length: int) -> list[int]:
        # Offsets of the tiles along one axis.
        stride = tile_length - overlap
        if stride <= 0:
            raise ValueError(f"overlap ({overlap}) must be smaller than the tile size ({tile_length})")
        starts: list[int] = []
        for start in range(0, length, stride):
            starts.append(start)
            if start + tile_length >= length:
                # This tile already reaches the border: a further tile would only re-cover part of it.
                break
        return starts

    tops = tile_starts(size.height, tile_size.height)
    return [
        _Tile(
            top=top,
            left=left,
            bottom=min(size.height, top + tile_size.height),
            right=min(size.width, left + tile_size.width),
        )
        for left in tile_starts(size.width, tile_size.width)
        for top in tops
    ]
@no_grad()
def _add_fixed_group_norm(self, image: Image.Image, inference_size: _ImageSize) -> None:
    """
    Replace the GroupNorm layers with FixedGroupNorm layers and calibrate their statistics.

    A FixedGroupNorm layer computes the group norm statistics on its first forward pass and keeps them for all
    subsequent passes. The calibration pass is run on a copy of `image` resized to `inference_size`, whose values are
    rescaled to match the per-channel mean and standard deviation of the original image. This ensures that the
    statistics of the GroupNorm layers are consistent across tiles when running tiled inference.

    Args:
        image: Image that is going to be processed with tiled inference.
        inference_size: Size to which the image is resized for the calibration pass.
    """
    # Collect the matches before injecting, so the tree is not modified while it is being walked.
    for group_norm, parent in list(self.walk(fl.GroupNorm)):
        FixedGroupNorm(group_norm).inject(parent)

    downscaled_image = image.resize((inference_size.width, inference_size.height))  # type: ignore

    image_tensor = image_to_tensor(image, device=self.device)
    downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device)
    downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())

    std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
    new_std, new_mean = torch.std_mean(downscaled_image_tensor, dim=[0, 2, 3], keepdim=True)
    # A constant channel has a standard deviation of 0: keep the ratio finite instead of computing 0 / 0.
    new_std = new_std.clamp(min=torch.finfo(new_std.dtype).eps)

    downscaled_image_tensor = (downscaled_image_tensor - new_mean) * (std / new_std) + mean
    downscaled_image_tensor = 2 * downscaled_image_tensor - 1
    # The statistics above are computed in the tensor's original dtype; the model input uses the model dtype, as in
    # `tiled_image_to_latents`.
    downscaled_image_tensor = downscaled_image_tensor.to(dtype=self.dtype)

    # We do a forward pass through the encoder and decoder to set the group norm stats in the FixedGroupNorm layers
    latents = self.encode(downscaled_image_tensor)
    self.decode(latents)
def _remove_fixed_group_norm(self) -> None:
    """
    Remove the FixedGroupNorm layers and restore the original GroupNorm layers.
    """
    # Collect the adapters before ejecting them, so the tree is not modified while it is being traversed.
    for fixed_group_norm in list(self.layers(FixedGroupNorm)):
        fixed_group_norm.eject()
@no_grad()
def _tiled_encode(self, image_tensor: torch.Tensor, tile_size: _ImageSize, blending: int = 64) -> torch.Tensor:
    """
    Encode an image to latents with tile-based inference and gradient blending to smooth tile edges.

    Sizes given in pixels are divided by 8 to obtain the matching latent sizes. When a single tile covers the whole
    image, the image is encoded in one pass.

    Args:
        image_tensor: Image tensor to encode; its last two dimensions are the height and the width.
        tile_size: Tile size, in pixels.
        blending: Overlap between neighbouring tiles, in pixels.

    Returns:
        The average of the encoded tiles, weighted by their blending masks.
    """
    latent_blending = blending // 8
    latent_size = _ImageSize(height=image_tensor.shape[2] // 8, width=image_tensor.shape[3] // 8)
    tiles = self._generate_latent_tiles(
        latent_size,
        tile_size=_ImageSize(height=tile_size.height // 8, width=tile_size.width // 8),
        overlap=latent_blending,
    )

    if len(tiles) == 1:
        return self.encode(image_tensor)

    accumulated = torch.zeros((1, 4, *latent_size), device=self.device)
    total_weight = torch.zeros_like(accumulated)

    for tile in tiles:
        rows = slice(tile.top, tile.bottom)
        columns = slice(tile.left, tile.right)
        # The tile covers the same region in pixel space, at 8 times the latent coordinates.
        pixel_rows = slice(tile.top * 8, tile.bottom * 8)
        pixel_columns = slice(tile.left * 8, tile.right * 8)

        encoded_tile = self.encode(image_tensor[:, :, pixel_rows, pixel_columns])

        mask = _create_blending_mask(
            _ImageSize(height=tile.bottom - tile.top, width=tile.right - tile.left),
            latent_blending,
            num_channels=4,
            device=self.device,
            # Sides lying on the border of the latent are not blended.
            is_edge=(
                tile.top == 0,
                tile.bottom == latent_size.height,
                tile.left == 0,
                tile.right == latent_size.width,
            ),
        )

        accumulated[:, :, rows, columns] += encoded_tile * mask
        total_weight[:, :, rows, columns] += mask

    return accumulated / total_weight
@no_grad()
def _tiled_decode(self, latents: torch.Tensor, tile_size: _ImageSize, blending: int = 64) -> torch.Tensor:
    """
    Decode latents to an image tensor with tile-based inference and gradient blending to smooth tile edges.

    Sizes given in pixels are divided by 8 to obtain the matching latent sizes. When a single tile covers the whole
    latent, the latent is decoded in one pass.

    Args:
        latents: Latents to decode; their last two dimensions are the height and the width.
        tile_size: Tile size, in pixels.
        blending: Overlap between neighbouring tiles, in pixels.

    Returns:
        The average of the decoded tiles, weighted by their blending masks.
    """
    latent_size = _ImageSize(height=latents.shape[2], width=latents.shape[3])
    tiles = self._generate_latent_tiles(
        latent_size,
        tile_size=_ImageSize(height=tile_size.height // 8, width=tile_size.width // 8),
        overlap=blending // 8,
    )

    if len(tiles) == 1:
        return self.decode(latents)

    accumulated = torch.zeros((1, 3, latent_size.height * 8, latent_size.width * 8), device=self.device)
    total_weight = torch.zeros_like(accumulated)

    for tile in tiles:
        decoded_tile = self.decode(latents[:, :, tile.top : tile.bottom, tile.left : tile.right])

        # Size and position of the decoded tile in pixel space, at 8 times the latent coordinates.
        tile_height = (tile.bottom - tile.top) * 8
        tile_width = (tile.right - tile.left) * 8
        rows = slice(tile.top * 8, tile.top * 8 + tile_height)
        columns = slice(tile.left * 8, tile.left * 8 + tile_width)

        mask = _create_blending_mask(
            _ImageSize(height=tile_height, width=tile_width),
            blending,
            num_channels=3,
            device=self.device,
            # Sides lying on the border of the latent are not blended.
            is_edge=(
                tile.top == 0,
                tile.bottom == latent_size.height,
                tile.left == 0,
                tile.right == latent_size.width,
            ),
        )

        accumulated[:, :, rows, columns] += decoded_tile * mask
        total_weight[:, :, rows, columns] += mask

    return accumulated / total_weight
@contextmanager
def tiled_inference(
    self, image: Image.Image, tile_size: tuple[int, int] = (512, 512), blending: int = 64
) -> Generator[None, None, None]:
    """
    Context manager for tiled inference operations to save VRAM for large images.

    On entry, the GroupNorm layers are replaced with layers whose statistics are captured on a version of the image
    resized to the tile size, which makes sure that the result is consistent across tiles. On exit, the original
    GroupNorm layers are restored and the tiling settings are cleared.

    Be careful not to use the normal `image_to_latents` and `latents_to_image` methods while this context manager is
    active, as this will fail silently and run the operation without tiling.

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):
        latents = lda.tiled_image_to_latents(sample_image)
        decoded_image = lda.tiled_latents_to_image(latents)
    ```

    Args:
        image: Image used to capture the GroupNorm statistics.
        tile_size: Tile size in pixels, given as (width, height).
        blending: Overlap between neighbouring tiles, in pixels.
    """
    try:
        self._blending = blending
        inference_size = _ImageSize(height=tile_size[1], width=tile_size[0])
        self._tile_size = inference_size
        self._add_fixed_group_norm(image, inference_size=inference_size)
        yield
    finally:
        # Always restore the original layers, even when the setup or the body of the `with` block failed.
        self._remove_fixed_group_norm()
        self._tile_size = None
        self._blending = None

View file

@ -16,7 +16,7 @@ def ref_path() -> Path:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def encoder(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder: def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
lda_weights = test_weights_path / "lda.safetensors" lda_weights = test_weights_path / "lda.safetensors"
if not lda_weights.is_file(): if not lda_weights.is_file():
warn(f"could not find weights at {lda_weights}, skipping") warn(f"could not find weights at {lda_weights}, skipping")
@ -39,9 +39,9 @@ def sample_image(ref_path: Path) -> Image.Image:
@no_grad() @no_grad()
def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image): def test_encode_decode_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = encoder.image_to_latents(sample_image) encoded = lda.image_to_latents(sample_image)
decoded = encoder.latents_to_image(encoded) decoded = lda.latents_to_image(encoded)
assert decoded.mode == "RGB" # type: ignore assert decoded.mode == "RGB" # type: ignore
@ -52,9 +52,62 @@ def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image:
@no_grad() @no_grad()
def test_encode_decode_images(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image): def test_encode_decode_images(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = encoder.images_to_latents([sample_image, sample_image]) encoded = lda.images_to_latents([sample_image, sample_image])
images = encoder.latents_to_images(encoded) images = lda.latents_to_images(encoded)
assert isinstance(images, list) assert isinstance(images, list)
assert len(images) == 2 assert len(images) == 2
ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9) ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
@no_grad()
def test_tiled_autoencoder(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a 2048x2048 image through tiled encoding and decoding with 512x512 tiles."""
    large_image = sample_image.resize((2048, 2048))  # type: ignore

    with lda.tiled_inference(large_image, tile_size=(512, 512)):
        latents = lda.tiled_image_to_latents(large_image)
        reconstruction = lda.tiled_latents_to_image(latents)

    ensure_similar_images(large_image, reconstruction, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_rectangular_tiles(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a 2048x2048 image through tiled encoding and decoding with non-square tiles."""
    large_image = sample_image.resize((2048, 2048))  # type: ignore

    with lda.tiled_inference(large_image, tile_size=(512, 1024)):
        latents = lda.tiled_image_to_latents(large_image)
        reconstruction = lda.tiled_latents_to_image(latents)

    ensure_similar_images(large_image, reconstruction, min_psnr=35, min_ssim=0.985)
@no_grad()
def test_tiled_autoencoder_large_tile(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a 1024x1024 image with a tile size larger than the image itself."""
    resized_image = sample_image.resize((1024, 1024))  # type: ignore

    with lda.tiled_inference(resized_image, tile_size=(2048, 2048)):
        latents = lda.tiled_image_to_latents(resized_image)
        reconstruction = lda.tiled_latents_to_image(latents)

    ensure_similar_images(resized_image, reconstruction, min_psnr=34, min_ssim=0.975)
@no_grad()
def test_tiled_autoencoder_rectangular_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a non-square image (a 300x500 crop upscaled 4 times) with 512x512 tiles."""
    cropped_image = sample_image.crop((0, 0, 300, 500))
    upscaled_size = (cropped_image.width * 4, cropped_image.height * 4)
    rectangular_image = cropped_image.resize(upscaled_size)  # type: ignore

    with lda.tiled_inference(rectangular_image, tile_size=(512, 512)):
        latents = lda.tiled_image_to_latents(rectangular_image)
        reconstruction = lda.tiled_latents_to_image(latents)

    ensure_similar_images(rectangular_image, reconstruction, min_psnr=37, min_ssim=0.985)
def test_value_error_tile_encode_no_context(lda: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
    """The tiled methods raise a ValueError when called outside the `tiled_inference` context manager."""
    latents = torch.randn(1, 8, 16, 16, device=lda.device)

    with pytest.raises(ValueError):
        lda.tiled_image_to_latents(sample_image)

    with pytest.raises(ValueError):
        lda.tiled_latents_to_image(latents)