From b42881e54e24b6302f941e623fa641203ecf7e6d Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Mon, 24 Jun 2024 16:35:43 +0000
Subject: [PATCH] Implement Tiled Autoencoder inference to save VRAM

---
 .../latent_diffusion/auto_encoder.py          | 342 +++++++++++++++++-
 .../latent_diffusion/test_auto_encoder.py     |  67 +++-
 2 files changed, 401 insertions(+), 8 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/auto_encoder.py b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
index f17c783..0d1794d 100644
--- a/src/refiners/foundationals/latent_diffusion/auto_encoder.py
+++ b/src/refiners/foundationals/latent_diffusion/auto_encoder.py
@@ -1,6 +1,13 @@
+from contextlib import contextmanager
+from typing import Generator, NamedTuple
+
+import torch
 from PIL import Image
 from torch import Tensor, device as Device, dtype as DType
+from torch.nn import functional as F
 
+from refiners.fluxion import layers as fl
+from refiners.fluxion.adapters.adapter import Adapter
 from refiners.fluxion.context import Contexts
 from refiners.fluxion.layers import (
     Chain,
@@ -15,7 +22,19 @@ from refiners.fluxion.layers import (
     Sum,
     Upsample,
 )
-from refiners.fluxion.utils import images_to_tensor, tensor_to_images
+from refiners.fluxion.utils import image_to_tensor, images_to_tensor, no_grad, tensor_to_image, tensor_to_images
+
+
+class _ImageSize(NamedTuple):
+    height: int
+    width: int
+
+
+class _Tile(NamedTuple):
+    top: int
+    left: int
+    bottom: int
+    right: int
 
 
 class Resnet(Sum):
@@ -187,6 +206,78 @@ class Decoder(Chain):
         )
 
 
+class FixedGroupNorm(fl.Chain, Adapter[fl.GroupNorm]):
+    """
+    Adapter for GroupNorm layers to fix the running mean and variance.
+
+    This is useful when running tiled inference with an autoencoder to ensure that the statistics of the GroupNorm layers
+    are consistent across tiles.
+    """
+
+    mean: torch.Tensor | None
+    var: torch.Tensor | None
+
+    def __init__(self, target: fl.GroupNorm) -> None:
+        self.mean = None
+        self.var = None
+        with self.setup_adapter(target):
+            super().__init__(fl.Lambda(self.compute_group_norm))
+
+    def compute_group_norm(self, x: torch.Tensor) -> torch.Tensor:
+        batch, channels, height, width = x.shape
+        # Reshape the tensor to apply batch norm to each group separately (to mimic group norm behavior)
+        x = x.reshape(
+            1,
+            batch * self.target.num_groups,
+            int(channels / self.target.num_groups),
+            height,
+            width,
+        )
+
+        if self.mean is None or self.var is None:
+            self.var, self.mean = torch.var_mean(x, dim=(0, 2, 3, 4), correction=0)
+
+        result = F.batch_norm(
+            input=x,
+            running_mean=self.mean,
+            running_var=self.var,
+            weight=None,
+            bias=None,
+            training=False,
+            momentum=0,
+            eps=self.target.eps,
+        )
+        result = result.reshape(batch, channels, height, width)
+        return result * self.target.weight.reshape(1, -1, 1, 1) + self.target.bias.reshape(1, -1, 1, 1)
+
+
+def _create_blending_mask(
+    size: _ImageSize,
+    blending: int,
+    num_channels: int,
+    device: torch.device | None = None,
+    is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
+) -> torch.Tensor:
+    mask = torch.ones(size, device=device)
+    if blending == 0:
+        return mask
+    blending = min(blending, min(size) // 2)
+
+    ramp = torch.linspace(0, 1, steps=blending, device=device)
+
+    # Apply ramps only if not at the corresponding edge
+    if not is_edge[0]:  # top
+        mask[:blending, :] *= ramp.view(-1, 1)
+    if not is_edge[1]:  # bottom
+        mask[-blending:, :] *= ramp.flip(0).view(-1, 1)
+    if not is_edge[2]:  # left
+        mask[:, :blending] *= ramp.view(1, -1)
+    if not is_edge[3]:  # right
+        mask[:, -blending:] *= ramp.flip(0).view(1, -1)
+
+    return mask.unsqueeze(0).unsqueeze(0).expand(1, num_channels, *size)
+
+
 class LatentDiffusionAutoencoder(Chain):
     """Latent diffusion autoencoder model.
 
@@ -211,6 +302,8 @@ class LatentDiffusionAutoencoder(Chain):
             Encoder(device=device, dtype=dtype),
             Decoder(device=device, dtype=dtype),
         )
+        self._tile_size = None
+        self._blending = None
 
     def encode(self, x: Tensor) -> Tensor:
         """Encode an image.
@@ -239,8 +332,29 @@ class LatentDiffusionAutoencoder(Chain):
         return x
 
     def image_to_latents(self, image: Image.Image) -> Tensor:
+        """
+        Encode an image to latents.
+        """
         return self.images_to_latents([image])
 
+    def tiled_image_to_latents(self, image: Image.Image) -> Tensor:
+        """
+        Convert an image to latents with gradient blending to smooth tile edges.
+
+        You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.
+
+        ```python
+        with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
+            latents = lda.tiled_image_to_latents(sample_image)
+        ```
+        """
+        if self._tile_size is None:
+            raise ValueError("Tiled inference context manager not active. Use `tiled_inference` method to activate.")
+
+        assert self._tile_size is not None and self._blending is not None
+        image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
+        image_tensor = 2 * image_tensor - 1
+        return self._tiled_encode(image_tensor, self._tile_size, self._blending)
+
     def images_to_latents(self, images: list[Image.Image]) -> Tensor:
         """Convert a list of images to latents.
 
@@ -260,11 +374,31 @@ class LatentDiffusionAutoencoder(Chain):
         return self.latents_to_image(x)
 
     def latents_to_image(self, x: Tensor) -> Image.Image:
+        """
+        Decode latents to an image.
+        """
         if x.shape[0] != 1:
             raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")
 
         return self.latents_to_images(x)[0]
 
+    def tiled_latents_to_image(self, x: Tensor) -> Image.Image:
+        """
+        Convert latents to an image with gradient blending to smooth tile edges.
+
+        You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.
+
+        ```python
+        with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
+            image = lda.tiled_latents_to_image(latents)
+        ```
+        """
+        if self._tile_size is None:
+            raise ValueError("Tiled inference context manager not active. Use `tiled_inference` method to activate.")
+
+        assert self._tile_size is not None and self._blending is not None
+        result = self._tiled_decode(x, self._tile_size, self._blending)
+        return tensor_to_image((result + 1) / 2)
+
     def latents_to_images(self, x: Tensor) -> list[Image.Image]:
         """Convert a tensor of latents to images.
 
@@ -277,3 +411,209 @@ class LatentDiffusionAutoencoder(Chain):
         x = self.decode(x)
         x = (x + 1) / 2
         return tensor_to_images(x)
+
+    @staticmethod
+    def _generate_latent_tiles(size: _ImageSize, tile_size: _ImageSize, overlap: int = 8) -> list[_Tile]:
+        """
+        Generate tiles for a given latent size and tile size with a given overlap.
+        """
+        tiles: list[_Tile] = []
+
+        for y in range(0, size.width, tile_size.width - overlap):
+            for x in range(0, size.height, tile_size.height - overlap):
+                tile = _Tile(
+                    top=max(0, x),
+                    left=max(0, y),
+                    bottom=min(size.height, x + tile_size.height),
+                    right=min(size.width, y + tile_size.width),
+                )
+                tiles.append(tile)
+
+        return tiles
+
+    @no_grad()
+    def _add_fixed_group_norm(self, image: Image.Image, inference_size: _ImageSize) -> None:
+        """
+        Set the running mean and variance of the group norm layers in the latent diffusion autoencoder.
+
+        We replace the GroupNorm layers with FixedGroupNorm layers that compute the group norm statistics on their
+        first forward pass and then fix them for all subsequent passes. This is useful when running tiled inference to
+        ensure that the statistics of the GroupNorm layers are consistent across tiles.
+        """
+        for group_norm, parent in self.walk(fl.GroupNorm):
+            FixedGroupNorm(group_norm).inject(parent)
+
+        downscaled_image = image.resize((inference_size.width, inference_size.height))  # type: ignore
+
+        image_tensor = image_to_tensor(image, device=self.device)
+        downscaled_image_tensor = image_to_tensor(downscaled_image, device=self.device)
+        downscaled_image_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())
+
+        std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
+        new_std, new_mean = torch.std_mean(downscaled_image_tensor, dim=[0, 2, 3], keepdim=True)
+
+        downscaled_image_tensor = (downscaled_image_tensor - new_mean) * (std / new_std) + mean
+        downscaled_image_tensor = 2 * downscaled_image_tensor - 1
+
+        # We do a forward pass through the encoder and decoder to set the group norm stats in the FixedGroupNorm layers
+        latents = self.encode(downscaled_image_tensor)
+        self.decode(latents)
+
+    def _remove_fixed_group_norm(self) -> None:
+        """
+        Remove the FixedGroupNorm layers and restore the original GroupNorm layers.
+        """
+        for fixed_group_norm in self.layers(FixedGroupNorm):
+            fixed_group_norm.eject()
+
+    @no_grad()
+    def _tiled_encode(self, image_tensor: torch.Tensor, tile_size: _ImageSize, blending: int = 64) -> torch.Tensor:
+        """
+        Encode an image to latents with tile-based inference and gradient blending to smooth tile edges.
+
+        If `tile_size` is not provided to the `tiled_inference` context manager, the default tile size of
+        512x512 is used.
+        """
+        latent_size = _ImageSize(height=image_tensor.shape[2] // 8, width=image_tensor.shape[3] // 8)
+        target_latent_tile_size = _ImageSize(height=tile_size.height // 8, width=tile_size.width // 8)
+        tiles = self._generate_latent_tiles(latent_size, tile_size=target_latent_tile_size, overlap=blending // 8)
+
+        if len(tiles) == 1:
+            return self.encode(image_tensor)
+
+        result = torch.zeros((1, 4, *latent_size), device=self.device)
+        weights = torch.zeros_like(result)
+
+        for latent_tile in tiles:
+            pixel_tile = image_tensor[
+                :,
+                :,
+                latent_tile.top * 8 : latent_tile.bottom * 8,
+                latent_tile.left * 8 : latent_tile.right * 8,
+            ]
+            encoded_tile = self.encode(pixel_tile)
+
+            is_edge = (
+                latent_tile.top == 0,
+                latent_tile.bottom == latent_size.height,
+                latent_tile.left == 0,
+                latent_tile.right == latent_size.width,
+            )
+
+            latent_tile_size = _ImageSize(
+                height=(latent_tile.bottom - latent_tile.top), width=(latent_tile.right - latent_tile.left)
+            )
+
+            tile_mask = _create_blending_mask(
+                latent_tile_size,
+                blending // 8,
+                num_channels=4,
+                device=self.device,
+                is_edge=is_edge,
+            )
+
+            result[
+                :,
+                :,
+                latent_tile.top : latent_tile.bottom,
+                latent_tile.left : latent_tile.right,
+            ] += encoded_tile * tile_mask
+
+            weights[
+                :,
+                :,
+                latent_tile.top : latent_tile.bottom,
+                latent_tile.left : latent_tile.right,
+            ] += tile_mask
+
+        return result / weights
+
+    @no_grad()
+    def _tiled_decode(self, latents: torch.Tensor, tile_size: _ImageSize, blending: int = 64) -> torch.Tensor:
+        """
+        Decode latents to an image with tile-based inference and gradient blending to smooth tile edges.
+
+        If `tile_size` is not provided to the `tiled_inference` context manager, the default tile size of
+        512x512 is used.
+        """
+        latent_size = _ImageSize(height=latents.shape[2], width=latents.shape[3])
+        pixel_size = _ImageSize(height=latent_size.height * 8, width=latent_size.width * 8)
+        target_latent_tile_size = _ImageSize(height=tile_size.height // 8, width=tile_size.width // 8)
+        tiles = self._generate_latent_tiles(latent_size, tile_size=target_latent_tile_size, overlap=blending // 8)
+        if len(tiles) == 1:
+            return self.decode(latents)
+
+        result = torch.zeros((1, 3, *pixel_size), device=self.device)
+        weights = torch.zeros_like(result)
+
+        for latent_tile in tiles:
+            pixel_offset = _ImageSize(height=latent_tile.top * 8, width=latent_tile.left * 8)
+            latent_tile_size = _ImageSize(
+                height=latent_tile.bottom - latent_tile.top, width=latent_tile.right - latent_tile.left
+            )
+            pixel_tile_size = _ImageSize(height=latent_tile_size.height * 8, width=latent_tile_size.width * 8)
+
+            pixel_tile = self.decode(
+                latents[
+                    :,
+                    :,
+                    latent_tile.top : latent_tile.bottom,
+                    latent_tile.left : latent_tile.right,
+                ]
+            )
+
+            is_edge = (
+                latent_tile.top == 0,
+                latent_tile.bottom == latent_size.height,
+                latent_tile.left == 0,
+                latent_tile.right == latent_size.width,
+            )
+
+            pixel_tile_mask = _create_blending_mask(
+                pixel_tile_size, blending, num_channels=3, device=self.device, is_edge=is_edge
+            )
+            result[
+                :,
+                :,
+                pixel_offset.height : pixel_offset.height + pixel_tile_size.height,
+                pixel_offset.width : pixel_offset.width + pixel_tile_size.width,
+            ] += pixel_tile * pixel_tile_mask
+
+            weights[
+                :,
+                :,
+                pixel_offset.height : pixel_offset.height + pixel_tile_size.height,
+                pixel_offset.width : pixel_offset.width + pixel_tile_size.width,
+            ] += pixel_tile_mask
+
+        return result / weights
+
+    @contextmanager
+    def tiled_inference(
+        self, image: Image.Image, tile_size: tuple[int, int] = (512, 512), blending: int = 64
+    ) -> Generator[None, None, None]:
+        """
+        Context manager for tiled inference operations, to save VRAM on large images.
+
+        This context manager sets up consistent GroupNorm statistics for tiled operations on the autoencoder: it
+        captures the statistics of the GroupNorm layers on a downsampled version of the image, fixes them for the
+        duration of the context, and restores the original layers on exit. This makes sure that the result is
+        consistent across tiles.
+
+        Be careful not to use the regular `image_to_latents` and `latents_to_image` methods while this context
+        manager is active: they will not raise an error, but will silently run without tiling.
+
+        ```python
+        with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):
+            latents = lda.tiled_image_to_latents(sample_image)
+            decoded_image = lda.tiled_latents_to_image(latents)
+        ```
+        """
+        try:
+            self._blending = blending
+            self._tile_size = _ImageSize(width=tile_size[0], height=tile_size[1])
+            self._add_fixed_group_norm(image, inference_size=self._tile_size)
+            yield
+        finally:
+            self._remove_fixed_group_norm()
+            self._tile_size = None
+            self._blending = None
diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py
index d140a24..dc6d77c 100644
--- a/tests/foundationals/latent_diffusion/test_auto_encoder.py
+++ b/tests/foundationals/latent_diffusion/test_auto_encoder.py
@@ -16,7 +16,7 @@ def ref_path() -> Path:
 
 
 @pytest.fixture(scope="module")
-def encoder(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
+def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
     lda_weights = test_weights_path / "lda.safetensors"
     if not lda_weights.is_file():
         warn(f"could not find weights at {lda_weights}, skipping")
@@ -39,9 +39,9 @@ def sample_image(ref_path: Path) -> Image.Image:
 
 
 @no_grad()
-def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    encoded = encoder.image_to_latents(sample_image)
-    decoded = encoder.latents_to_image(encoded)
+def test_encode_decode_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = lda.image_to_latents(sample_image)
+    decoded = lda.latents_to_image(encoded)
 
     assert decoded.mode == "RGB"  # type: ignore
 
@@ -52,9 +52,62 @@ def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image:
 
 
 @no_grad()
-def test_encode_decode_images(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
-    encoded = encoder.images_to_latents([sample_image, sample_image])
-    images = encoder.latents_to_images(encoded)
+def test_encode_decode_images(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    encoded = lda.images_to_latents([sample_image, sample_image])
+    images = lda.latents_to_images(encoded)
     assert isinstance(images, list)
     assert len(images) == 2
     ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
+
+
+@no_grad()
+def test_tiled_autoencoder(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    sample_image = sample_image.resize((2048, 2048))  # type: ignore
+
+    with lda.tiled_inference(sample_image, tile_size=(512, 512)):
+        encoded = lda.tiled_image_to_latents(sample_image)
+        result = lda.tiled_latents_to_image(encoded)
+
+    ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
+
+
+@no_grad()
+def test_tiled_autoencoder_rectangular_tiles(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    sample_image = sample_image.resize((2048, 2048))  # type: ignore
+
+    with lda.tiled_inference(sample_image, tile_size=(512, 1024)):
+        encoded = lda.tiled_image_to_latents(sample_image)
+        result = lda.tiled_latents_to_image(encoded)
+
+    ensure_similar_images(sample_image, result, min_psnr=35, min_ssim=0.985)
+
+
+@no_grad()
+def test_tiled_autoencoder_large_tile(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    sample_image = sample_image.resize((1024, 1024))  # type: ignore
+
+    with lda.tiled_inference(sample_image, tile_size=(2048, 2048)):
+        encoded = lda.tiled_image_to_latents(sample_image)
+        result = lda.tiled_latents_to_image(encoded)
+
+    ensure_similar_images(sample_image, result, min_psnr=34, min_ssim=0.975)
+
+
+@no_grad()
+def test_tiled_autoencoder_rectangular_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
+    sample_image = sample_image.crop((0, 0, 300, 500))
+    sample_image = sample_image.resize((sample_image.width * 4, sample_image.height * 4))  # type: ignore
+
+    with lda.tiled_inference(sample_image, tile_size=(512, 512)):
+        encoded = lda.tiled_image_to_latents(sample_image)
+        result = lda.tiled_latents_to_image(encoded)
+
+    ensure_similar_images(sample_image, result, min_psnr=37, min_ssim=0.985)
+
+
+def test_value_error_tile_encode_no_context(lda: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
+    with pytest.raises(ValueError):
+        lda.tiled_image_to_latents(sample_image)
+
+    with pytest.raises(ValueError):
+        lda.tiled_latents_to_image(torch.randn(1, 8, 16, 16, device=lda.device))
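Taken together, the new API is used as in the minimal sketch below, which mirrors the docstring examples in the patch. The autoencoder construction, device choice, and file paths are illustrative assumptions (weight loading is elided), not part of the patch itself.

```python
import torch
from PIL import Image

from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder

# Hypothetical setup: device and weight loading are placeholders, not part of this patch.
lda = LatentDiffusionAutoencoder(device=torch.device("cuda"), dtype=torch.float32)

image = Image.open("large_image.png").convert("RGB")  # e.g. a 2048x2048 image

with no_grad(), lda.tiled_inference(image, tile_size=(512, 512), blending=64):
    # GroupNorm statistics are captured on a downscaled copy of the image and frozen,
    # so the independently processed tiles stay consistent with each other.
    latents = lda.tiled_image_to_latents(image)
    reconstructed = lda.tiled_latents_to_image(latents)

reconstructed.save("reconstructed.png")
```

Outside the context manager, the regular `image_to_latents` and `latents_to_image` methods behave as before, while `tiled_image_to_latents` and `tiled_latents_to_image` raise a `ValueError`, as exercised by `test_value_error_tile_encode_no_context`.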