implement Restart method for latent diffusion

This commit is contained in:
limiteinductive 2023-10-12 15:04:57 +02:00 committed by Benjamin Trom
parent e35dce825f
commit 7a62049d54
5 changed files with 172 additions and 3 deletions

View file

@ -0,0 +1,110 @@
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, TypeVar
import torch
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
T = TypeVar("T", bound=LatentDiffusionModel)
def add_noise_interval(
scheduler: Scheduler,
/,
x: torch.Tensor,
noise: torch.Tensor,
initial_timestep: torch.Tensor,
target_timestep: torch.Tensor,
) -> torch.Tensor:
initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
return noised_x
@dataclass
class Restart(Generic[T]):
"""
Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
(https://arxiv.org/pdf/2306.14878.pdf)
Works only with the DDIM scheduler for now.
"""
ldm: T
num_steps: int = 10
num_iterations: int = 2
start_time: float = 0.1
end_time: float = 2
def __post_init__(self) -> None:
assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
def __call__(
self,
x: torch.Tensor,
/,
clip_text_embedding: torch.Tensor,
condition_scale: float = 7.5,
**kwargs: torch.Tensor,
) -> torch.Tensor:
original_scheduler = self.ldm.scheduler
new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
new_scheduler.timesteps = self.timesteps
self.ldm.scheduler = new_scheduler
for _ in range(self.num_iterations):
noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
x = add_noise_interval(
new_scheduler,
x=x,
noise=noise,
initial_timestep=self.timesteps[-1],
target_timestep=self.timesteps[0],
)
for step in range(len(self.timesteps) - 1):
x = self.ldm(
x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
)
self.ldm.scheduler = original_scheduler
return x
@cached_property
def start_step(self) -> int:
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
@cached_property
def end_timestep(self) -> int:
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))
@cached_property
def timesteps(self) -> torch.Tensor:
return (
torch.round(
torch.linspace(
start=int(self.ldm.scheduler.timesteps[self.start_step]),
end=self.end_timestep,
steps=self.num_steps,
)
)
.flip(0)
.to(device=self.device, dtype=torch.int64)
)
@property
def device(self) -> torch.device:
return self.ldm.device
@property
def dtype(self) -> torch.dtype:
return self.ldm.dtype

View file

@ -1,4 +1,4 @@
from torch import Tensor, device as Device, arange, sqrt
from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
@ -10,8 +10,16 @@ class DDIM(Scheduler):
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
device: Device | str = "cpu",
dtype: Dtype = float32,
) -> None:
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
super().__init__(
num_inference_steps,
num_train_timesteps,
initial_diffusion_rate,
final_diffusion_rate,
device=device,
dtype=dtype,
)
self.timesteps = self._generate_timesteps()
def _generate_timesteps(self) -> Tensor:
@ -26,7 +34,11 @@ class DDIM(Scheduler):
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
self.timesteps[step] - self.num_train_timesteps // self.num_inference_steps,
(
self.timesteps[step + 1]
if step < self.num_inference_steps - 1
else tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
self.cumulative_scale_factors[previous_timestep]

View file

@ -20,6 +20,7 @@ from refiners.foundationals.latent_diffusion import (
)
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.clip.concepts import ConceptExtender
@ -221,6 +222,11 @@ def expected_multi_diffusion(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
@pytest.fixture
def expected_restart(ref_path: Path) -> Image.Image:
return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB")
@pytest.fixture
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
@ -1558,3 +1564,43 @@ def test_t2i_adapter_xl_canny(
predicted_image = sdxl.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_image)
@torch.no_grad()
def test_restart(
sd15_ddim: StableDiffusion_1,
expected_restart: Image.Image,
test_device: torch.device,
):
sd15 = sd15_ddim
n_steps = 30
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
sd15.set_num_inference_steps(n_steps)
restart = Restart(ldm=sd15)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=8,
)
if step == restart.start_step:
x = restart(
x,
clip_text_embedding=clip_text_embedding,
condition_scale=8,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)

View file

@ -44,6 +44,7 @@ Special cases:
- `expected_t2i_adapter_xl_canny.png`
- `expected_image_sdxl_ip_adapter_plus_woman.png`
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
- `expected_restart.png`
## Other images

Binary file not shown.

After

Width:  |  Height:  |  Size: 504 KiB