Add sample_noise staticmethod and modify add_noise to support batched steps

Laurent 2024-04-18 09:42:31 +00:00 committed by Laureηt
parent 7427c171f6
commit 17246708b9
5 changed files with 122 additions and 8 deletions


@@ -35,32 +35,68 @@ class LatentDiffusionModel(fl.Module, ABC):
     def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
         self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)

+    @staticmethod
+    def sample_noise(
+        size: tuple[int, ...],
+        device: Device | None = None,
+        dtype: DType | None = None,
+        offset_noise: float | None = None,
+    ) -> torch.Tensor:
+        """Sample noise from a normal distribution with an optional offset.
+
+        Args:
+            size: The size of the noise tensor.
+            device: The device to put the noise tensor on.
+            dtype: The data type of the noise tensor.
+            offset_noise: The offset of the noise tensor.
+                Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.
+        """
+        noise = torch.randn(size=size, device=device, dtype=dtype)
+        if offset_noise is not None:
+            noise += offset_noise * torch.randn(size=(size[0], size[1], 1, 1), device=device, dtype=dtype)
+        return noise
+
     def init_latents(
         self,
         size: tuple[int, int],
         init_image: Image.Image | None = None,
         noise: Tensor | None = None,
     ) -> Tensor:
+        """Initialize the latents for the diffusion process.
+
+        Args:
+            size: The size of the latent (in pixel space).
+            init_image: The image to use as initialization for the latents.
+            noise: The noise to add to the latents.
+        """
         height, width = size
+        latent_height = height // 8
+        latent_width = width // 8
+
         if noise is None:
-            noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
+            noise = LatentDiffusionModel.sample_noise(
+                size=(1, 4, latent_height, latent_width),
+                device=self.device,
+                dtype=self.dtype,
+            )
+
         assert list(noise.shape[2:]) == [
-            height // 8,
-            width // 8,
+            latent_height,
+            latent_width,
         ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
+
         if init_image is None:
-            x = noise
+            latent = noise
         else:
             resized = init_image.resize(size=(width, height))  # type: ignore
             encoded_image = self.lda.image_to_latents(resized)
-            x = self.solver.add_noise(
+            latent = self.solver.add_noise(
                 x=encoded_image,
                 noise=noise,
                 step=self.solver.first_inference_step,
             )
-        return self.solver.scale_model_input(x, step=-1)
+
+        return self.solver.scale_model_input(latent, step=-1)

     @property
     def steps(self) -> list[int]:
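
A minimal sketch of how the new sample_noise staticmethod might be called at training time; the tensor size, device, and offset_noise value below are illustrative assumptions rather than values taken from this commit:

    import torch

    from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel

    # Draw a batch of 4 latent-shaped noise tensors; offset_noise adds a per-sample,
    # per-channel offset shared across spatial positions (see the linked blog post).
    noise = LatentDiffusionModel.sample_noise(
        size=(4, 4, 64, 64),
        device=torch.device("cpu"),
        dtype=torch.float32,
        offset_noise=0.1,  # illustrative value, not taken from the commit
    )
    assert noise.shape == (4, 4, 64, 64)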


@@ -4,7 +4,19 @@ from enum import Enum
 from typing import TypeVar

 import numpy as np
-from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor
+from torch import (
+    Generator,
+    Tensor,
+    arange,
+    device as Device,
+    dtype as DType,
+    float32,
+    linspace,
+    log,
+    sqrt,
+    stack,
+    tensor,
+)

 from refiners.fluxion import layers as fl
@@ -208,7 +220,7 @@ class Solver(fl.Module, ABC):
             offset=self.params.timesteps_offset,
         )

-    def add_noise(
+    def _add_noise(
         self,
         x: Tensor,
         noise: Tensor,
@@ -227,9 +239,43 @@ class Solver(fl.Module, ABC):
         timestep = self.timesteps[step]
         cumulative_scale_factors = self.cumulative_scale_factors[timestep]
         noise_stds = self.noise_std[timestep]
+        # noisify the latents, arXiv:2006.11239 Eq. 4
         noised_x = cumulative_scale_factors * x + noise_stds * noise
         return noised_x
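
For reference, the noising line above is the DDPM forward process cited in the comment (arXiv:2006.11239, Eq. 4); assuming cumulative_scale_factors holds the square root of the cumulative alphas and noise_std holds the complementary standard deviation, it reads:

    x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I)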
+
+    def add_noise(
+        self,
+        x: Tensor,
+        noise: Tensor,
+        step: int | list[int],
+    ) -> Tensor:
+        """Add noise to the input tensor using the solver's parameters.
+
+        Args:
+            x: The input tensor to add noise to.
+            noise: The noise tensor to add to the input tensor.
+            step: The current step(s) of the diffusion process.
+
+        Returns:
+            The input tensor with added noise.
+        """
+        if isinstance(step, list):
+            assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
+            return stack(
+                tensors=[
+                    self._add_noise(
+                        x=x[i],
+                        noise=noise[i],
+                        step=step[i],
+                    )
+                    for i in range(x.shape[0])
+                ],
+                dim=0,
+            )
+
+        return self._add_noise(x=x, noise=noise, step=step)

     def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
         """Remove noise from the input tensor using the current step of the diffusion process.

Binary file not shown (image; 1.4 MiB before and after).


@@ -0,0 +1,14 @@
+import torch
+
+from refiners.fluxion.utils import manual_seed, no_grad
+from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
+
+
+@no_grad()
+def test_sample_noise():
+    manual_seed(2)
+    latents_0 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64))
+    manual_seed(2)
+    latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)
+
+    assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
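
The test above exercises the offset_noise=0.0 path; a sketch of an additional check one could add for a nonzero offset (the 0.1 value and the test name are arbitrary illustrative choices), reusing the same imports:

    @no_grad()
    def test_sample_noise_with_offset():
        manual_seed(2)
        latents_0 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64))
        manual_seed(2)
        latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.1)

        assert not torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)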


@@ -198,6 +198,24 @@ def test_solver_device(test_device: Device):
     assert noised.device == test_device


+def test_solver_add_noise(test_device: Device):
+    scheduler = DDIM(num_inference_steps=30, device=test_device)
+    latent = randn(1, 4, 32, 32, device=test_device)
+    noise = randn(1, 4, 32, 32, device=test_device)
+    noised = scheduler.add_noise(
+        x=latent,
+        noise=noise,
+        step=0,
+    )
+    noised_double = scheduler.add_noise(
+        x=latent.repeat(2, 1, 1, 1),
+        noise=noise.repeat(2, 1, 1, 1),
+        step=[0, 0],
+    )
+    assert allclose(noised, noised_double[0])
+    assert allclose(noised, noised_double[1])
+
+
 @pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
 def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
     scheduler = DDIM(