mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
Add sample_noise staticmethod and modify add_noise to support batched steps
This commit is contained in:
parent
7427c171f6
commit
17246708b9
|
@ -35,32 +35,68 @@ class LatentDiffusionModel(fl.Module, ABC):
|
|||
def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
    """Replace the solver with one rebuilt for a new inference schedule.

    Args:
        num_steps: Forwarded to the solver's `rebuild` as `num_inference_steps`.
        first_step: Forwarded to the solver's `rebuild` as `first_inference_step`.
    """
    self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
|
||||
|
||||
@staticmethod
|
||||
def sample_noise(
|
||||
size: tuple[int, ...],
|
||||
device: Device | None = None,
|
||||
dtype: DType | None = None,
|
||||
offset_noise: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Sample noise from a normal distribution with an optional offset.
|
||||
|
||||
Args:
|
||||
size: The size of the noise tensor.
|
||||
device: The device to put the noise tensor on.
|
||||
dtype: The data type of the noise tensor.
|
||||
offset_noise: The offset of the noise tensor.
|
||||
Useful at training time, see https://www.crosslabs.org/blog/diffusion-with-offset-noise.
|
||||
"""
|
||||
noise = torch.randn(size=size, device=device, dtype=dtype)
|
||||
if offset_noise is not None:
|
||||
noise += offset_noise * torch.randn(size=(size[0], size[1], 1, 1), device=device, dtype=dtype)
|
||||
return noise
|
||||
|
||||
def init_latents(
    self,
    size: tuple[int, int],
    init_image: Image.Image | None = None,
    noise: Tensor | None = None,
) -> Tensor:
    """Initialize the latents for the diffusion process.

    Args:
        size: The size of the latent (in pixel space), as `(height, width)`.
        init_image: The image to use as initialization for the latents.
        noise: The noise to add to the latents. When omitted, noise of shape
            `(1, 4, height // 8, width // 8)` is sampled on the model's device
            and dtype.

    Returns:
        The initial latent, passed through the solver's `scale_model_input`
        with `step=-1`.
    """
    height, width = size
    latent_height = height // 8
    latent_width = width // 8

    if noise is None:
        noise = LatentDiffusionModel.sample_noise(
            size=(1, 4, latent_height, latent_width),
            device=self.device,
            dtype=self.dtype,
        )

    # Only the last two dimensions are checked; they must match the size
    # divided by 8.
    assert list(noise.shape[2:]) == [
        latent_height,
        latent_width,
    ], f"noise shape is not compatible: {noise.shape}, with size: {size}"

    if init_image is None:
        latent = noise
    else:
        # PIL's resize expects (width, height), the reverse of `size`.
        resized = init_image.resize(size=(width, height))  # type: ignore
        encoded_image = self.lda.image_to_latents(resized)
        latent = self.solver.add_noise(
            x=encoded_image,
            noise=noise,
            step=self.solver.first_inference_step,
        )

    return self.solver.scale_model_input(latent, step=-1)
|
||||
|
||||
@property
|
||||
def steps(self) -> list[int]:
|
||||
|
|
|
@ -4,7 +4,19 @@ from enum import Enum
|
|||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
from torch import Generator, Tensor, arange, device as Device, dtype as DType, float32, linspace, log, sqrt, tensor
|
||||
from torch import (
|
||||
Generator,
|
||||
Tensor,
|
||||
arange,
|
||||
device as Device,
|
||||
dtype as DType,
|
||||
float32,
|
||||
linspace,
|
||||
log,
|
||||
sqrt,
|
||||
stack,
|
||||
tensor,
|
||||
)
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
|
||||
|
@ -208,7 +220,7 @@ class Solver(fl.Module, ABC):
|
|||
offset=self.params.timesteps_offset,
|
||||
)
|
||||
|
||||
def add_noise(
|
||||
def _add_noise(
|
||||
self,
|
||||
x: Tensor,
|
||||
noise: Tensor,
|
||||
|
@ -227,9 +239,43 @@ class Solver(fl.Module, ABC):
|
|||
timestep = self.timesteps[step]
|
||||
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
|
||||
noise_stds = self.noise_std[timestep]
|
||||
|
||||
# noisify the latents, arXiv:2006.11239 Eq. 4
|
||||
noised_x = cumulative_scale_factors * x + noise_stds * noise
|
||||
return noised_x
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
x: Tensor,
|
||||
noise: Tensor,
|
||||
step: int | list[int],
|
||||
) -> Tensor:
|
||||
"""Add noise to the input tensor using the solver's parameters.
|
||||
|
||||
Args:
|
||||
x: The input tensor to add noise to.
|
||||
noise: The noise tensor to add to the input tensor.
|
||||
step: The current step(s) of the diffusion process.
|
||||
|
||||
Returns:
|
||||
The input tensor with added noise.
|
||||
"""
|
||||
if isinstance(step, list):
|
||||
assert len(x) == len(noise) == len(step), "x, noise, and step must have the same length"
|
||||
return stack(
|
||||
tensors=[
|
||||
self._add_noise(
|
||||
x=x[i],
|
||||
noise=noise[i],
|
||||
step=step[i],
|
||||
)
|
||||
for i in range(x.shape[0])
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
return self._add_noise(x=x, noise=noise, step=step)
|
||||
|
||||
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
"""Remove noise from the input tensor using the current step of the diffusion process.
|
||||
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 1.4 MiB After Width: | Height: | Size: 1.4 MiB |
14
tests/foundationals/latent_diffusion/test_model.py
Normal file
14
tests/foundationals/latent_diffusion/test_model.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
import torch
|
||||
|
||||
from refiners.fluxion.utils import manual_seed, no_grad
|
||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
|
||||
|
||||
@no_grad()
def test_sample_noise():
    """Check that a zero offset gives the same noise as no offset, for the same seed."""
    manual_seed(2)
    latents_0 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64))
    manual_seed(2)
    latents_1 = LatentDiffusionModel.sample_noise(size=(1, 4, 64, 64), offset_noise=0.0)

    assert torch.allclose(latents_0, latents_1, atol=1e-6, rtol=0)
|
|
@ -198,6 +198,24 @@ def test_solver_device(test_device: Device):
|
|||
assert noised.device == test_device
|
||||
|
||||
|
||||
def test_solver_add_noise(test_device: Device):
    """Check that `add_noise` with a list of steps matches the single-step call per item."""
    scheduler = DDIM(num_inference_steps=30, device=test_device)
    latent = randn(1, 4, 32, 32, device=test_device)
    noise = randn(1, 4, 32, 32, device=test_device)
    noised = scheduler.add_noise(
        x=latent,
        noise=noise,
        step=0,
    )
    # The same latent and noise duplicated along the batch dimension, each at step 0.
    noised_double = scheduler.add_noise(
        x=latent.repeat(2, 1, 1, 1),
        noise=noise.repeat(2, 1, 1, 1),
        step=[0, 0],
    )
    assert allclose(noised, noised_double[0])
    assert allclose(noised, noised_double[1])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("noise_schedule", [NoiseSchedule.UNIFORM, NoiseSchedule.QUADRATIC, NoiseSchedule.KARRAS])
|
||||
def test_solver_noise_schedules(noise_schedule: NoiseSchedule, test_device: Device):
|
||||
scheduler = DDIM(
|
||||
|
|
Loading…
Reference in a new issue