mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
Implement Tiled Autoencoder inference to save VRAM
This commit is contained in:
parent
b16cb7fbb6
commit
b42881e54e
|
@ -1,6 +1,13 @@
|
|||
from contextlib import contextmanager
|
||||
from typing import Generator, NamedTuple
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
from torch.nn import functional as F
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
from refiners.fluxion.context import Contexts
|
||||
from refiners.fluxion.layers import (
|
||||
Chain,
|
||||
|
@ -15,7 +22,19 @@ from refiners.fluxion.layers import (
|
|||
Sum,
|
||||
Upsample,
|
||||
)
|
||||
from refiners.fluxion.utils import images_to_tensor, tensor_to_images
|
||||
from refiners.fluxion.utils import image_to_tensor, images_to_tensor, no_grad, tensor_to_image, tensor_to_images
|
||||
|
||||
|
||||
class _ImageSize(NamedTuple):
|
||||
height: int
|
||||
width: int
|
||||
|
||||
|
||||
class _Tile(NamedTuple):
|
||||
top: int
|
||||
left: int
|
||||
bottom: int
|
||||
right: int
|
||||
|
||||
|
||||
class Resnet(Sum):
|
||||
|
@ -187,6 +206,78 @@ class Decoder(Chain):
|
|||
)
|
||||
|
||||
|
||||
class FixedGroupNorm(fl.Chain, Adapter[fl.GroupNorm]):
    """
    Adapter for GroupNorm layers that freezes the per-group mean and variance.

    The statistics are computed on the first forward pass and reused for every later pass. This is useful when
    running tiled inference with an autoencoder, to ensure that the statistics of the GroupNorm layers are consistent
    across tiles.
    """

    # Both stay `None` until the first call to `compute_group_norm`.
    mean: torch.Tensor | None
    var: torch.Tensor | None

    def __init__(self, target: fl.GroupNorm) -> None:
        self.mean = None
        self.var = None
        with self.setup_adapter(target):
            super().__init__(fl.Lambda(self.compute_group_norm))

    def compute_group_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize `x` per group with the stored statistics, then apply the target's weight and bias."""
        batch, channels, height, width = x.shape
        num_groups = self.target.num_groups

        # Fold each (sample, group) pair into the channel axis so that batch norm normalizes every group
        # separately, which mimics group norm.
        grouped = x.reshape(1, batch * num_groups, channels // num_groups, height, width)

        if self.mean is None or self.var is None:
            # First pass: capture the statistics that all later passes will reuse.
            self.var, self.mean = torch.var_mean(grouped, dim=(0, 2, 3, 4), correction=0)

        normalized = F.batch_norm(
            input=grouped,
            running_mean=self.mean,
            running_var=self.var,
            weight=None,
            bias=None,
            training=False,
            momentum=0,
            eps=self.target.eps,
        ).reshape(batch, channels, height, width)

        scale = self.target.weight.reshape(1, -1, 1, 1)
        shift = self.target.bias.reshape(1, -1, 1, 1)
        return normalized * scale + shift
|
||||
|
||||
|
||||
def _create_blending_mask(
|
||||
size: _ImageSize,
|
||||
blending: int,
|
||||
num_channels: int,
|
||||
device: torch.device | None = None,
|
||||
is_edge: tuple[bool, bool, bool, bool] = (False, False, False, False),
|
||||
) -> torch.Tensor:
|
||||
mask = torch.ones(size, device=device)
|
||||
if blending == 0:
|
||||
return mask
|
||||
blending = min(blending, min(size) // 2)
|
||||
|
||||
ramp = torch.linspace(0, 1, steps=blending, device=device)
|
||||
|
||||
# Apply ramps only if not at the corresponding edge
|
||||
if not is_edge[0]: # top
|
||||
mask[:blending, :] *= ramp.view(-1, 1)
|
||||
if not is_edge[1]: # bottom
|
||||
mask[-blending:, :] *= ramp.flip(0).view(-1, 1)
|
||||
if not is_edge[2]: # left
|
||||
mask[:, :blending] *= ramp.view(1, -1)
|
||||
if not is_edge[3]: # right
|
||||
mask[:, -blending:] *= ramp.flip(0).view(1, -1)
|
||||
|
||||
return mask.unsqueeze(0).unsqueeze(0).expand(1, num_channels, *size)
|
||||
|
||||
|
||||
class LatentDiffusionAutoencoder(Chain):
|
||||
"""Latent diffusion autoencoder model.
|
||||
|
||||
|
@ -211,6 +302,8 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
Encoder(device=device, dtype=dtype),
|
||||
Decoder(device=device, dtype=dtype),
|
||||
)
|
||||
self._tile_size = None
|
||||
self._blending = None
|
||||
|
||||
def encode(self, x: Tensor) -> Tensor:
|
||||
"""Encode an image.
|
||||
|
@ -239,8 +332,29 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
return x
|
||||
|
||||
def image_to_latents(self, image: Image.Image) -> Tensor:
    """
    Encode a single image to latents by delegating to `images_to_latents` with a one-element list.
    """
    single_image_batch = [image]
    return self.images_to_latents(single_image_batch)
|
||||
|
||||
def tiled_image_to_latents(self, image: Image.Image) -> Tensor:
    """
    Convert an image to latents with gradient blending to smooth tile edges.

    You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
        latents = lda.tiled_image_to_latents(sample_image)
    ```

    Raises:
        ValueError: If the `tiled_inference` context manager is not active.
    """
    tile_size, blending = self._tile_size, self._blending
    # Both attributes are set together by `tiled_inference`; checking them explicitly (rather than with an
    # `assert`) keeps the guard effective when Python runs with `-O`.
    if tile_size is None or blending is None:
        raise ValueError("Tiled inference context manager not active. Use `tiled_inference` method to activate.")

    image_tensor = image_to_tensor(image, device=self.device, dtype=self.dtype)
    # Rescale before encoding (maps [0, 1] to [-1, 1], assuming `image_to_tensor` returns values in [0, 1]).
    image_tensor = 2 * image_tensor - 1
    return self._tiled_encode(image_tensor, tile_size, blending)
|
||||
|
||||
def images_to_latents(self, images: list[Image.Image]) -> Tensor:
|
||||
"""Convert a list of images to latents.
|
||||
|
||||
|
@ -260,11 +374,31 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
return self.latents_to_image(x)
|
||||
|
||||
def latents_to_image(self, x: Tensor) -> Image.Image:
    """
    Decode latents holding exactly one sample to a single image.

    Raises:
        ValueError: If the first dimension of `x` is not 1.
    """
    batch_size = x.shape[0]
    if batch_size != 1:
        raise ValueError(f"Expected batch size of 1, got {batch_size}")

    images = self.latents_to_images(x)
    return images[0]
|
||||
|
||||
def tiled_latents_to_image(self, x: Tensor) -> Image.Image:
    """
    Convert latents to an image with gradient blending to smooth tile edges.

    You need to activate the tiled inference context manager with the `tiled_inference` method to use this method.

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024)):
        image = lda.tiled_latents_to_image(latents)
    ```

    Raises:
        ValueError: If the `tiled_inference` context manager is not active.
    """
    tile_size, blending = self._tile_size, self._blending
    # Both attributes are set together by `tiled_inference`; checking them explicitly (rather than with an
    # `assert`) keeps the guard effective when Python runs with `-O`.
    if tile_size is None or blending is None:
        raise ValueError("Tiled inference context manager not active. Use `tiled_inference` method to activate.")

    result = self._tiled_decode(x, tile_size, blending)
    # Inverse of the `2 * t - 1` rescaling applied before encoding.
    return tensor_to_image((result + 1) / 2)
|
||||
|
||||
def latents_to_images(self, x: Tensor) -> list[Image.Image]:
|
||||
"""Convert a tensor of latents to images.
|
||||
|
||||
|
@ -277,3 +411,209 @@ class LatentDiffusionAutoencoder(Chain):
|
|||
x = self.decode(x)
|
||||
x = (x + 1) / 2
|
||||
return tensor_to_images(x)
|
||||
|
||||
@staticmethod
def _generate_latent_tiles(size: _ImageSize, tile_size: _ImageSize, overlap: int = 8) -> list[_Tile]:
    """
    Generate tiles for a given latent size and tile size with a given overlap.

    Tiles are laid out on a regular grid with a stride of `tile_size - overlap` and are clipped to `size`. They
    are returned column by column (all rows of the left-most column first).

    Args:
        size: Size of the area to cover.
        tile_size: Maximum size of a tile.
        overlap: Number of rows/columns shared by two neighbouring tiles.

    Raises:
        ValueError: If `overlap` is negative or not smaller than both tile dimensions.
    """
    row_stride = tile_size.height - overlap
    column_stride = tile_size.width - overlap
    if overlap < 0 or row_stride <= 0 or column_stride <= 0:
        # A zero stride makes `range` raise a cryptic error, a negative one silently yields no tile at all, and
        # a negative overlap leaves gaps between tiles.
        raise ValueError(
            f"overlap must be in [0, min(tile_size)), got overlap={overlap} for a tile of "
            f"{tile_size.height}x{tile_size.width}"
        )

    def tile_starts(length: int, tile_length: int, stride: int) -> list[int]:
        """Start offsets along one axis, stopping as soon as a tile reaches the end of that axis."""
        starts: list[int] = []
        for start in range(0, length, stride):
            starts.append(start)
            if start + tile_length >= length:
                # Any further tile would lie entirely inside this one.
                break
        return starts

    tops = tile_starts(size.height, tile_size.height, row_stride)
    lefts = tile_starts(size.width, tile_size.width, column_stride)

    return [
        _Tile(
            top=top,
            left=left,
            bottom=min(size.height, top + tile_size.height),
            right=min(size.width, left + tile_size.width),
        )
        for left in lefts
        for top in tops
    ]
|
||||
|
||||
@no_grad()
def _add_fixed_group_norm(self, image: Image.Image, inference_size: _ImageSize) -> None:
    """
    Set the running mean and variance of the group norm layers in the latent diffusion autoencoder.

    We replace the GroupNorm layers with FixedGroupNorm layers that will compute the group norm statistics on
    their first forward pass and then fix them for all subsequent passes. This is useful when running tiled
    inference to ensure that the statistics of the GroupNorm layers are consistent across tiles.

    Args:
        image: Image whose statistics are captured.
        inference_size: Size the image is resized to for the calibration pass.
    """
    for group_norm, parent in self.walk(fl.GroupNorm):
        FixedGroupNorm(group_norm).inject(parent)

    resized_image = image.resize((inference_size.width, inference_size.height))  # type: ignore

    image_tensor = image_to_tensor(image, device=self.device)
    resized_tensor = image_to_tensor(resized_image, device=self.device)
    # Resampling can overshoot the value range of the original image.
    resized_tensor.clamp_(min=image_tensor.min(), max=image_tensor.max())

    # Match the per-channel mean and standard deviation of the resized image to those of the original.
    std, mean = torch.std_mean(image_tensor, dim=[0, 2, 3], keepdim=True)
    new_std, new_mean = torch.std_mean(resized_tensor, dim=[0, 2, 3], keepdim=True)
    # A constant channel has a standard deviation of 0; without this floor the ratio below is 0 / 0 and the
    # captured statistics become NaN.
    new_std = new_std.clamp(min=torch.finfo(new_std.dtype).eps)

    resized_tensor = (resized_tensor - new_mean) * (std / new_std) + mean
    resized_tensor = 2 * resized_tensor - 1
    # Feed the model the same dtype that `tiled_image_to_latents` uses.
    resized_tensor = resized_tensor.to(dtype=self.dtype)

    # We do a forward pass through the encoder and decoder to set the group norm stats in the FixedGroupNorm layers
    latents = self.encode(resized_tensor)
    self.decode(latents)
|
||||
|
||||
def _remove_fixed_group_norm(self) -> None:
    """
    Eject every FixedGroupNorm adapter found in this model, restoring the original GroupNorm layers.
    """
    for adapter in self.layers(FixedGroupNorm):
        adapter.eject()
|
||||
|
||||
@no_grad()
def _tiled_encode(self, image_tensor: torch.Tensor, tile_size: _ImageSize, blending: int = 64) -> torch.Tensor:
    """
    Encode an image to latents with tile-based inference and gradient blending to smooth tile edges.

    Args:
        image_tensor: Image tensor indexed as (batch, channels, height, width).
        tile_size: Tile size in pixels; it is divided by 8 to get the latent tile size.
        blending: Blending width in pixels; it is divided by 8 to get the overlap between latent tiles.

    Returns:
        The encoded latents. When a single tile covers the whole image, this is a plain `encode` call.
    """
    latent_blending = blending // 8
    latent_size = _ImageSize(height=image_tensor.shape[2] // 8, width=image_tensor.shape[3] // 8)
    target_latent_tile_size = _ImageSize(height=tile_size.height // 8, width=tile_size.width // 8)
    tiles = self._generate_latent_tiles(latent_size, tile_size=target_latent_tile_size, overlap=latent_blending)

    if len(tiles) == 1:
        return self.encode(image_tensor)

    # Weighted sum of the encoded tiles and the sum of the weights, accumulated in torch's default dtype.
    result = torch.zeros((1, 4, *latent_size), device=self.device)
    weights = torch.zeros_like(result)
    output_dtype = image_tensor.dtype

    for latent_tile in tiles:
        pixel_tile = image_tensor[
            :,
            :,
            latent_tile.top * 8 : latent_tile.bottom * 8,
            latent_tile.left * 8 : latent_tile.right * 8,
        ]
        encoded_tile = self.encode(pixel_tile)
        output_dtype = encoded_tile.dtype

        # Sides lying on the border of the full latent have no neighbour to blend with.
        is_edge = (
            latent_tile.top == 0,
            latent_tile.bottom == latent_size.height,
            latent_tile.left == 0,
            latent_tile.right == latent_size.width,
        )
        latent_tile_size = _ImageSize(
            height=latent_tile.bottom - latent_tile.top,
            width=latent_tile.right - latent_tile.left,
        )
        tile_mask = _create_blending_mask(
            latent_tile_size,
            latent_blending,
            num_channels=4,
            device=self.device,
            is_edge=is_edge,
        )

        rows = slice(latent_tile.top, latent_tile.bottom)
        columns = slice(latent_tile.left, latent_tile.right)
        result[:, :, rows, columns] += encoded_tile * tile_mask
        weights[:, :, rows, columns] += tile_mask

    # Return the same dtype as the single-tile path, which returns the encoder output unchanged.
    return (result / weights).to(dtype=output_dtype)
|
||||
|
||||
@no_grad()
def _tiled_decode(self, latents: torch.Tensor, tile_size: _ImageSize, blending: int = 64) -> torch.Tensor:
    """
    Convert latents to an image tensor with tile-based inference and gradient blending to smooth tile edges.

    Args:
        latents: Latents indexed as (batch, channels, height, width).
        tile_size: Tile size in pixels; it is divided by 8 to get the latent tile size.
        blending: Blending width in pixels; it is divided by 8 to get the overlap between latent tiles.

    Returns:
        The decoded tensor, 8 times larger than `latents` in height and width. When a single tile covers the
        whole latent, this is a plain `decode` call.
    """
    latent_size = _ImageSize(height=latents.shape[2], width=latents.shape[3])
    pixel_size = _ImageSize(height=latent_size.height * 8, width=latent_size.width * 8)
    target_latent_tile_size = _ImageSize(height=tile_size.height // 8, width=tile_size.width // 8)
    tiles = self._generate_latent_tiles(latent_size, tile_size=target_latent_tile_size, overlap=blending // 8)

    if len(tiles) == 1:
        return self.decode(latents)

    # Weighted sum of the decoded tiles and the sum of the weights, accumulated in torch's default dtype.
    result = torch.zeros((1, 3, *pixel_size), device=self.device)
    weights = torch.zeros_like(result)
    output_dtype = latents.dtype

    for latent_tile in tiles:
        latent_rows = slice(latent_tile.top, latent_tile.bottom)
        latent_columns = slice(latent_tile.left, latent_tile.right)
        pixel_tile = self.decode(latents[:, :, latent_rows, latent_columns])
        output_dtype = pixel_tile.dtype

        pixel_tile_size = _ImageSize(
            height=(latent_tile.bottom - latent_tile.top) * 8,
            width=(latent_tile.right - latent_tile.left) * 8,
        )

        # Sides lying on the border of the full latent have no neighbour to blend with.
        is_edge = (
            latent_tile.top == 0,
            latent_tile.bottom == latent_size.height,
            latent_tile.left == 0,
            latent_tile.right == latent_size.width,
        )
        pixel_tile_mask = _create_blending_mask(
            pixel_tile_size, blending, num_channels=3, device=self.device, is_edge=is_edge
        )

        pixel_top, pixel_left = latent_tile.top * 8, latent_tile.left * 8
        pixel_rows = slice(pixel_top, pixel_top + pixel_tile_size.height)
        pixel_columns = slice(pixel_left, pixel_left + pixel_tile_size.width)
        result[:, :, pixel_rows, pixel_columns] += pixel_tile * pixel_tile_mask
        weights[:, :, pixel_rows, pixel_columns] += pixel_tile_mask

    # Return the same dtype as the single-tile path, which returns the decoder output unchanged.
    return (result / weights).to(dtype=output_dtype)
|
||||
|
||||
@contextmanager
def tiled_inference(
    self, image: Image.Image, tile_size: tuple[int, int] = (512, 512), blending: int = 64
) -> Generator[None, None, None]:
    """
    Context manager for tiled inference operations to save VRAM for large images.

    This context manager sets up consistent GroupNorm statistics for performing tiled operations on the
    autoencoder, including setting and resetting group norm statistics. This makes sure that the result is
    consistent across tiles, by capturing the statistics of the GroupNorm layers on a resized version of the
    image.

    Be careful not to use the normal `image_to_latents` and `latents_to_image` methods while this context manager
    is active, as this will fail silently and run the operation without tiling.

    ```python
    with lda.tiled_inference(sample_image, tile_size=(768, 1024), blending=32):
        latents = lda.tiled_image_to_latents(sample_image)
        decoded_image = lda.tiled_latents_to_image(latents)
    ```

    Args:
        image: Image used to capture the GroupNorm statistics.
        tile_size: Tile size in pixels, as `(width, height)`.
        blending: Width in pixels of the band blended between neighbouring tiles.

    Raises:
        ValueError: If `blending` is negative, or too large for `tile_size` to leave a positive stride between
            tiles once both are divided by 8 (tiling is done in latent space).
    """
    tile_width, tile_height = tile_size
    # Validate before touching the model: tiles advance by `tile // 8 - blending // 8` latent pixels, which
    # must be positive for the tiling to cover the image.
    if blending < 0 or min(tile_width, tile_height) // 8 <= blending // 8:
        raise ValueError(
            f"Invalid tiling parameters: tile_size={tile_size}, blending={blending}. `blending` must be "
            "non-negative and `tile_size // 8` must be greater than `blending // 8` in both dimensions."
        )

    try:
        self._blending = blending
        self._tile_size = _ImageSize(width=tile_width, height=tile_height)
        self._add_fixed_group_norm(image, inference_size=self._tile_size)
        yield
    finally:
        # Always restore the original GroupNorm layers, even if the setup or the body failed.
        self._remove_fixed_group_norm()
        self._tile_size = None
        self._blending = None
|
||||
|
|
|
@ -16,7 +16,7 @@ def ref_path() -> Path:
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def encoder(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
|
||||
def lda(test_weights_path: Path, test_device: torch.device) -> LatentDiffusionAutoencoder:
|
||||
lda_weights = test_weights_path / "lda.safetensors"
|
||||
if not lda_weights.is_file():
|
||||
warn(f"could not find weights at {lda_weights}, skipping")
|
||||
|
@ -39,9 +39,9 @@ def sample_image(ref_path: Path) -> Image.Image:
|
|||
|
||||
|
||||
@no_grad()
|
||||
def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||
encoded = encoder.image_to_latents(sample_image)
|
||||
decoded = encoder.latents_to_image(encoded)
|
||||
def test_encode_decode_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||
encoded = lda.image_to_latents(sample_image)
|
||||
decoded = lda.latents_to_image(encoded)
|
||||
|
||||
assert decoded.mode == "RGB" # type: ignore
|
||||
|
||||
|
@ -52,9 +52,62 @@ def test_encode_decode_image(encoder: LatentDiffusionAutoencoder, sample_image:
|
|||
|
||||
|
||||
@no_grad()
def test_encode_decode_images(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Encode a batch of two identical images and check that the second decoded image matches the input."""
    encoded = lda.images_to_latents([sample_image, sample_image])
    images = lda.latents_to_images(encoded)
    assert isinstance(images, list)
    assert len(images) == 2
    ensure_similar_images(sample_image, images[1], min_psnr=20, min_ssim=0.9)
|
||||
|
||||
|
||||
@no_grad()
def test_tiled_autoencoder(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a 2048x2048 image through tiled encoding and decoding with 512x512 tiles."""
    large_image = sample_image.resize((2048, 2048))  # type: ignore

    with lda.tiled_inference(large_image, tile_size=(512, 512)):
        latents = lda.tiled_image_to_latents(large_image)
        reconstruction = lda.tiled_latents_to_image(latents)

    ensure_similar_images(large_image, reconstruction, min_psnr=35, min_ssim=0.985)
|
||||
|
||||
|
||||
@no_grad()
def test_tiled_autoencoder_rectangular_tiles(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a 2048x2048 image through tiled encoding and decoding with non-square (512, 1024) tiles."""
    large_image = sample_image.resize((2048, 2048))  # type: ignore

    with lda.tiled_inference(large_image, tile_size=(512, 1024)):
        latents = lda.tiled_image_to_latents(large_image)
        reconstruction = lda.tiled_latents_to_image(latents)

    ensure_similar_images(large_image, reconstruction, min_psnr=35, min_ssim=0.985)
|
||||
|
||||
|
||||
@no_grad()
def test_tiled_autoencoder_large_tile(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a 1024x1024 image with a tile size (2048x2048) larger than the image itself."""
    resized_image = sample_image.resize((1024, 1024))  # type: ignore

    with lda.tiled_inference(resized_image, tile_size=(2048, 2048)):
        latents = lda.tiled_image_to_latents(resized_image)
        reconstruction = lda.tiled_latents_to_image(latents)

    ensure_similar_images(resized_image, reconstruction, min_psnr=34, min_ssim=0.975)
|
||||
|
||||
|
||||
@no_grad()
def test_tiled_autoencoder_rectangular_image(lda: LatentDiffusionAutoencoder, sample_image: Image.Image):
    """Round-trip a non-square image (a 300x500 crop upscaled 4x) with 512x512 tiles."""
    cropped = sample_image.crop((0, 0, 300, 500))
    tall_image = cropped.resize((cropped.width * 4, cropped.height * 4))  # type: ignore

    with lda.tiled_inference(tall_image, tile_size=(512, 512)):
        latents = lda.tiled_image_to_latents(tall_image)
        reconstruction = lda.tiled_latents_to_image(latents)

    ensure_similar_images(tall_image, reconstruction, min_psnr=37, min_ssim=0.985)
|
||||
|
||||
|
||||
def test_value_error_tile_encode_no_context(lda: LatentDiffusionAutoencoder, sample_image: Image.Image) -> None:
    """Both tiled entry points must raise `ValueError` when used outside of `tiled_inference`."""
    with pytest.raises(ValueError):
        lda.tiled_image_to_latents(sample_image)

    with pytest.raises(ValueError):
        dummy_latents = torch.randn(1, 8, 16, 16, device=lda.device)
        lda.tiled_latents_to_image(dummy_latents)
|
||||
|
|
Loading…
Reference in a new issue