diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py
index aea4bcd..4bb51e8 100644
--- a/src/refiners/foundationals/latent_diffusion/model.py
+++ b/src/refiners/foundationals/latent_diffusion/model.py
@@ -35,32 +35,68 @@ class LatentDiffusionModel(fl.Module, ABC):
     def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
         self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
 
+    @staticmethod
+    def sample_noise(
+        size: tuple[int, ...],
+        device: Device | None = None,
+        dtype: DType | None = None,
+        offset_noise: float | None = None,
+    ) -> torch.Tensor:
+        """Sample noise from a normal distribution with an optional offset.
+
+        Args:
+            size: The size of the noise tensor.
+            device: The device to put the noise tensor on.
+            dtype: The data type of the noise tensor.
+            offset_noise: The offset of the noise tensor.
+                Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.
+        """
+        noise = torch.randn(size=size, device=device, dtype=dtype)
+        if offset_noise is not None:
+            noise += offset_noise * torch.randn(size=(size[0], size[1], 1, 1), device=device, dtype=dtype)
+        return noise
+
     def init_latents(
         self,
         size: tuple[int, int],
         init_image: Image.Image | None = None,
         noise: Tensor | None = None,
     ) -> Tensor:
+        """Initialize the latents for the diffusion process.
+
+        Args:
+            size: The size of the latent (in pixel space).
+            init_image: The image to use as initialization for the latents.
+            noise: The noise to add to the latents.
+        """
         height, width = size
+        latent_height = height // 8
+        latent_width = width // 8
+
         if noise is None:
-            noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
+            noise = LatentDiffusionModel.sample_noise(
+                size=(1, 4, latent_height, latent_width),
+                device=self.device,
+                dtype=self.dtype,
+            )
+
         assert list(noise.shape[2:]) == [
-            height // 8,
-            width // 8,
+            latent_height,
+            latent_width,
         ], f"noise shape is not compatible: {noise.shape}, with size: {size}"
 
         if init_image is None:
-            x = noise
+            latent = noise
         else:
             resized = init_image.resize(size=(width, height))  # type: ignore
             encoded_image = self.lda.image_to_latents(resized)
-            x = self.solver.add_noise(
+            latent = self.solver.add_noise(
                 x=encoded_image,
                 noise=noise,
                 step=self.solver.first_inference_step,
             )
-        return self.solver.scale_model_input(x, step=-1)
+        return self.solver.scale_model_input(latent, step=-1)
 
     @property
     def steps(self) -> list[int]:
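Reviewer note (not part of the patch): a minimal usage sketch of the new sample_noise helper, to make the offset-noise path concrete. The offset_noise value of 0.1 is purely illustrative, not a tuned default, and device/dtype are left at torch's defaults.

import torch

from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel

# Plain Gaussian noise for a 512x512 image: the latent tensor is 8x smaller spatially.
plain = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64))

# With offset noise, each (batch, channel) pair also receives a constant random shift,
# broadcast over the spatial dimensions (see the crosslabs.org post linked in the docstring).
shifted = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.1)

assert plain.shape == shifted.shape == (1, 4, 64, 64)

Since init_latents does not pass offset_noise, the offset path is only exercised when a trainer calls sample_noise directly.
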
diff --git a/src/refiners/foundationals/latent_diffusion/solvers/solver.py b/src/refiners/foundationals/latent_diffusion/solvers/solver.py
index caaaaec..24c570b 100644
--- a/src/refiners/foundationals/latent_diffusion/solvers/solver.py
+++ b/src/refiners/foundationals/latent_diffusion/solvers/solver.py
@@ -4,7 +4,19 @@ from enum import Enum
 from typing import TypeVar
 
 import numpy as np
-from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor
+from torch import (
+    Generator,
+    Tensor,
+    arange,
+    device as Device,
+    dtype as DType,
+    float32,
+    linspace,
+    log,
+    sqrt,
+    stack,
+    tensor,
+)
 
 from refiners.fluxion import layers as fl
 
@@ -208,7 +220,7 @@ class Solver(fl.Module, ABC):
             offset=self.params.timesteps_offset,
         )
 
-    def add_noise(
+    def _add_noise(
         self,
         x: Tensor,
         noise: Tensor,
@@ -227,9 +239,43 @@ class Solver(fl.Module, ABC):
         timestep = self.timesteps[step]
         cumulative_scale_factors = self.cumulative_scale_factors[timestep]
         noise_stds = self.noise_std[timestep]
+
+        # noisify the latents, arXiv:2006.11239 Eq. 4
         noised_x = cumulative_scale_factors * x + noise_stds * noise
         return noised_x
 
+    def add_noise(
+        self,
+        x: Tensor,
+        noise: Tensor,
+        step: int | list[int],
+    ) -> Tensor:
+        """Add noise to the input tensor using the solver's parameters.
+
+        Args:
+            x: The input tensor to add noise to.
+            noise: The noise tensor to add to the input tensor.
+            step: The current step(s) of the diffusion process.
+
+        Returns:
+            The input tensor with added noise.
+        """
+        if isinstance(step, list):
+            assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
+            return stack(
+                tensors=[
+                    self._add_noise(
+                        x=x[i],
+                        noise=noise[i],
+                        step=step[i],
+                    )
+                    for i in range(x.shape[0])
+                ],
+                dim=0,
+            )
+
+        return self._add_noise(x=x, noise=noise, step=step)
+
     def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
         """Remove noise from the input tensor using the current step of the diffusion process.
 
diff --git a/tests/e2e/test_lightning_ref/expected_lightning_base_4step.png b/tests/e2e/test_lightning_ref/expected_lightning_base_4step.png
index 10ecf26..63df54c 100644
Binary files a/tests/e2e/test_lightning_ref/expected_lightning_base_4step.png and b/tests/e2e/test_lightning_ref/expected_lightning_base_4step.png differ
diff --git a/tests/foundationals/latent_diffusion/test_model.py b/tests/foundationals/latent_diffusion/test_model.py
new file mode 100644
index 0000000..78443d4
--- /dev/null
+++ b/tests/foundationals/latent_diffusion/test_model.py
@@ -0,0 +1,14 @@
+import torch
+
+from refiners.fluxion.utils import manual_seed, no_grad
+from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
+
+
+@no_grad()
+def test_sample_noise():
+    manual_seed(2)
+    latents_0 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64))
+    manual_seed(2)
+    latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)
+
+    assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
diff --git a/tests/foundationals/latent_diffusion/test_solvers.py b/tests/foundationals/latent_diffusion/test_solvers.py
index 2065806..c653c48 100644
--- a/tests/foundationals/latent_diffusion/test_solvers.py
+++ b/tests/foundationals/latent_diffusion/test_solvers.py
@@ -198,6 +198,24 @@ def test_solver_device(test_device: Device):
     assert noised.device == test_device
 
 
+def test_solver_add_noise(test_device: Device):
+    scheduler = DDIM(num_inference_steps=30, device=test_device)
+    latent = randn(1, 4, 32, 32, device=test_device)
+    noise = randn(1, 4, 32, 32, device=test_device)
+    noised = scheduler.add_noise(
+        x=latent,
+        noise=noise,
+        step=0,
+    )
+    noised_double = scheduler.add_noise(
+        x=latent.repeat(2, 1, 1, 1),
+        noise=noise.repeat(2, 1, 1, 1),
+        step=[0, 0],
+    )
+    assert allclose(noised, noised_double[0])
+    assert allclose(noised, noised_double[1])
+
+
 @pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
 def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
     scheduler = DDIM(
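
Reviewer note (not part of the patch): a short sketch of the batched add_noise API with two distinct steps, as a training loop might use it; test_solver_add_noise above only exercises identical steps. The DDIM import path and the step values [0, 15] are illustrative assumptions, not prescribed by this diff.

import torch

from refiners.foundationals.latent_diffusion.solvers import DDIM

solver = DDIM(num_inference_steps=30)

# A batch of two latents, each paired with its own noise sample.
latents = torch.randn(2, 4, 32, 32)
noise = torch.randn(2, 4, 32, 32)

# Passing a list of steps noisifies each batch element with the cumulative scale
# factor and noise std of its own timestep, then stacks the results along dim 0.
noised = solver.add_noise(x=latents, noise=noise, step=[0, 15])

assert noised.shape == (2, 4, 32, 32)

Passing a single int keeps the previous behaviour, so existing callers such as init_latents (which passes first_inference_step) are unaffected.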