mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
Add `sample_noise` staticmethod and modify `add_noise` to support batched steps
This commit is contained in:
parent
7427c171f6
commit
17246708b9
|
@ -35,32 +35,68 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
|
||||||
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def sample_noise(
|
||||||
|
size: tuple[int, ...],
|
||||||
|
device: Device | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
offset_noise: float | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Sample noise from a normal distribution with an optional offset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size: The size of the noise tensor.
|
||||||
|
device: The device to put the noise tensor on.
|
||||||
|
dtype: The data type of the noise tensor.
|
||||||
|
offset_noise: The offset of the noise tensor.
|
||||||
|
Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.
|
||||||
|
"""
|
||||||
|
noise = torch.randn(size=size, device=device, dtype=dtype)
|
||||||
|
if offset_noise is not None:
|
||||||
|
noise += offset_noise * torch.randn(size=(size[0], size[1], 1, 1), device=device, dtype=dtype)
|
||||||
|
return noise
|
||||||
|
|
||||||
def init_latents(
|
def init_latents(
|
||||||
self,
|
self,
|
||||||
size: tuple[int, int],
|
size: tuple[int, int],
|
||||||
init_image: Image.Image | None = None,
|
init_image: Image.Image | None = None,
|
||||||
noise: Tensor | None = None,
|
noise: Tensor | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
"""Initialize the latents for the diffusion process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size: The size of the latent (in pixel space).
|
||||||
|
init_image: The image to use as initialization for the latents.
|
||||||
|
noise: The noise to add to the latents.
|
||||||
|
"""
|
||||||
height, width = size
|
height, width = size
|
||||||
|
latent_height = height // 8
|
||||||
|
latent_width = width // 8
|
||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
|
noise = LatentDiffusionModel.sample_noise(
|
||||||
|
size=(1, 4, latent_height, latent_width),
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
assert list(noise.shape[2:]) == [
|
assert list(noise.shape[2:]) == [
|
||||||
height // 8,
|
latent_height,
|
||||||
width // 8,
|
latent_width,
|
||||||
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
|
], f"noise shape is not compatible: {noise.shape}, with size: {size}"
|
||||||
|
|
||||||
if init_image is None:
|
if init_image is None:
|
||||||
x = noise
|
latent = noise
|
||||||
else:
|
else:
|
||||||
resized = init_image.resize(size=(width, height)) # type: ignore
|
resized = init_image.resize(size=(width, height)) # type: ignore
|
||||||
encoded_image = self.lda.image_to_latents(resized)
|
encoded_image = self.lda.image_to_latents(resized)
|
||||||
x = self.solver.add_noise(
|
latent = self.solver.add_noise(
|
||||||
x=encoded_image,
|
x=encoded_image,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
step=self.solver.first_inference_step,
|
step=self.solver.first_inference_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.solver.scale_model_input(x, step=-1)
|
return self.solver.scale_model_input(latent, step=-1)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def steps(self) -> list[int]:
|
def steps(self) -> list[int]:
|
||||||
|
|
|
@ -4,7 +4,19 @@ from enum import Enum
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor
|
from torch import (
|
||||||
|
Generator,
|
||||||
|
Tensor,
|
||||||
|
arange,
|
||||||
|
device as Device,
|
||||||
|
dtype as DType,
|
||||||
|
float32,
|
||||||
|
linspace,
|
||||||
|
log,
|
||||||
|
sqrt,
|
||||||
|
stack,
|
||||||
|
tensor,
|
||||||
|
)
|
||||||
|
|
||||||
from refiners.fluxion import layers as fl
|
from refiners.fluxion import layers as fl
|
||||||
|
|
||||||
|
@ -208,7 +220,7 @@ class Solver(fl.Module, ABC):
|
||||||
offset=self.params.timesteps_offset,
|
offset=self.params.timesteps_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_noise(
|
def _add_noise(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
noise: Tensor,
|
noise: Tensor,
|
||||||
|
@ -227,9 +239,43 @@ class Solver(fl.Module, ABC):
|
||||||
timestep = self.timesteps[step]
|
timestep = self.timesteps[step]
|
||||||
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
||||||
noise_stds = self.noise_std[timestep]
|
noise_stds = self.noise_std[timestep]
|
||||||
|
|
||||||
|
# noisify the latents, arXiv:2006.11239 Eq. 4
|
||||||
noised_x = cumulative_scale_factors * x + noise_stds * noise
|
noised_x = cumulative_scale_factors * x + noise_stds * noise
|
||||||
return noised_x
|
return noised_x
|
||||||
|
|
||||||
|
def add_noise(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
noise: Tensor,
|
||||||
|
step: int | list[int],
|
||||||
|
) -> Tensor:
|
||||||
|
"""Add noise to the input tensor using the solver's parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: The input tensor to add noise to.
|
||||||
|
noise: The noise tensor to add to the input tensor.
|
||||||
|
step: The current step(s) of the diffusion process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The input tensor with added noise.
|
||||||
|
"""
|
||||||
|
if isinstance(step, list):
|
||||||
|
assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
|
||||||
|
return stack(
|
||||||
|
tensors=[
|
||||||
|
self._add_noise(
|
||||||
|
x=x[i],
|
||||||
|
noise=noise[i],
|
||||||
|
step=step[i],
|
||||||
|
)
|
||||||
|
for i in range(x.shape[0])
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._add_noise(x=x, noise=noise, step=step)
|
||||||
|
|
||||||
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
"""Remove noise from the input tensor using the current step of the diffusion process.
|
"""Remove noise from the input tensor using the current step of the diffusion process.
|
||||||
|
|
||||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 1.4 MiB After Width: | Height: | Size: 1.4 MiB |
14
tests/foundationals/latent_diffusion/test_model.py
Normal file
14
tests/foundationals/latent_diffusion/test_model.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import manual_seed, no_grad
|
||||||
|
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
def test_sample_noise():
    """A zero `offset_noise` must give the same noise as no offset for the same seed."""
    shape = (1, 4, 64, 64)

    manual_seed(2)
    baseline = LatentDiffusionModel.sample_noise(size=shape)

    # Re-seed so both calls start from the same random state.
    manual_seed(2)
    with_zero_offset = LatentDiffusionModel.sample_noise(size=shape, offset_noise=0.0)

    assert torch.allclose(baseline, with_zero_offset, atol=1e-6, rtol=0)
|
|
@ -198,6 +198,24 @@ def test_solver_device(test_device: Device):
|
||||||
assert noised.device == test_device
|
assert noised.device == test_device
|
||||||
|
|
||||||
|
|
||||||
|
def test_solver_add_noise(test_device: Device):
    """Batched `add_noise` (a list of steps) must match the single-step result for each item."""
    solver = DDIM(num_inference_steps=30, device=test_device)
    latent = randn(1, 4, 32, 32, device=test_device)
    noise = randn(1, 4, 32, 32, device=test_device)

    # Reference: one sample noised at step 0.
    single = solver.add_noise(x=latent, noise=noise, step=0)

    # Same sample duplicated along the batch dimension, each copy noised at step 0.
    batched = solver.add_noise(
        x=latent.repeat(2, 1, 1, 1),
        noise=noise.repeat(2, 1, 1, 1),
        step=[0, 0],
    )

    assert allclose(single, batched[0])
    assert allclose(single, batched[1])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
||||||
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
||||||
scheduler = DDIM(
|
scheduler = DDIM(
|
||||||
|
|
Loading…
Reference in a new issue