mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
implement Restart method for latent diffusion
This commit is contained in:
parent
e35dce825f
commit
7a62049d54
110
src/refiners/foundationals/latent_diffusion/restart.py
Normal file
110
src/refiners/foundationals/latent_diffusion/restart.py
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||||
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=LatentDiffusionModel)
|
||||||
|
|
||||||
|
|
||||||
|
def add_noise_interval(
|
||||||
|
scheduler: Scheduler,
|
||||||
|
/,
|
||||||
|
x: torch.Tensor,
|
||||||
|
noise: torch.Tensor,
|
||||||
|
initial_timestep: torch.Tensor,
|
||||||
|
target_timestep: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
|
||||||
|
target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
|
||||||
|
|
||||||
|
factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
|
||||||
|
noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
|
||||||
|
return noised_x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Restart(Generic[T]):
|
||||||
|
"""
|
||||||
|
Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
|
||||||
|
(https://arxiv.org/pdf/2306.14878.pdf)
|
||||||
|
|
||||||
|
Works only with the DDIM scheduler for now.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ldm: T
|
||||||
|
num_steps: int = 10
|
||||||
|
num_iterations: int = 2
|
||||||
|
start_time: float = 0.1
|
||||||
|
end_time: float = 2
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
/,
|
||||||
|
clip_text_embedding: torch.Tensor,
|
||||||
|
condition_scale: float = 7.5,
|
||||||
|
**kwargs: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
original_scheduler = self.ldm.scheduler
|
||||||
|
new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
|
||||||
|
new_scheduler.timesteps = self.timesteps
|
||||||
|
self.ldm.scheduler = new_scheduler
|
||||||
|
|
||||||
|
for _ in range(self.num_iterations):
|
||||||
|
noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
|
||||||
|
x = add_noise_interval(
|
||||||
|
new_scheduler,
|
||||||
|
x=x,
|
||||||
|
noise=noise,
|
||||||
|
initial_timestep=self.timesteps[-1],
|
||||||
|
target_timestep=self.timesteps[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
for step in range(len(self.timesteps) - 1):
|
||||||
|
x = self.ldm(
|
||||||
|
x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ldm.scheduler = original_scheduler
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def start_step(self) -> int:
|
||||||
|
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
|
||||||
|
return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def end_timestep(self) -> int:
|
||||||
|
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
|
||||||
|
return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def timesteps(self) -> torch.Tensor:
|
||||||
|
return (
|
||||||
|
torch.round(
|
||||||
|
torch.linspace(
|
||||||
|
start=int(self.ldm.scheduler.timesteps[self.start_step]),
|
||||||
|
end=self.end_timestep,
|
||||||
|
steps=self.num_steps,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.flip(0)
|
||||||
|
.to(device=self.device, dtype=torch.int64)
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return self.ldm.device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self) -> torch.dtype:
|
||||||
|
return self.ldm.dtype
|
|
@ -1,4 +1,4 @@
|
||||||
from torch import Tensor, device as Device, arange, sqrt
|
from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor
|
||||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,8 +10,16 @@ class DDIM(Scheduler):
|
||||||
initial_diffusion_rate: float = 8.5e-4,
|
initial_diffusion_rate: float = 8.5e-4,
|
||||||
final_diffusion_rate: float = 1.2e-2,
|
final_diffusion_rate: float = 1.2e-2,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
|
dtype: Dtype = float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
|
super().__init__(
|
||||||
|
num_inference_steps,
|
||||||
|
num_train_timesteps,
|
||||||
|
initial_diffusion_rate,
|
||||||
|
final_diffusion_rate,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
self.timesteps = self._generate_timesteps()
|
self.timesteps = self._generate_timesteps()
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def _generate_timesteps(self) -> Tensor:
|
||||||
|
@ -26,7 +34,11 @@ class DDIM(Scheduler):
|
||||||
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
timestep, previous_timestep = (
|
timestep, previous_timestep = (
|
||||||
self.timesteps[step],
|
self.timesteps[step],
|
||||||
self.timesteps[step] - self.num_train_timesteps // self.num_inference_steps,
|
(
|
||||||
|
self.timesteps[step + 1]
|
||||||
|
if step < self.num_inference_steps - 1
|
||||||
|
else tensor(data=[0], device=self.device, dtype=self.dtype)
|
||||||
|
),
|
||||||
)
|
)
|
||||||
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
|
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
|
||||||
self.cumulative_scale_factors[previous_timestep]
|
self.cumulative_scale_factors[previous_timestep]
|
||||||
|
|
|
@ -20,6 +20,7 @@ from refiners.foundationals.latent_diffusion import (
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
||||||
|
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
|
@ -221,6 +222,11 @@ def expected_multi_diffusion(ref_path: Path) -> Image.Image:
|
||||||
return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_restart(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||||
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||||
|
@ -1558,3 +1564,43 @@ def test_t2i_adapter_xl_canny(
|
||||||
predicted_image = sdxl.lda.decode_latents(x)
|
predicted_image = sdxl.lda.decode_latents(x)
|
||||||
|
|
||||||
ensure_similar_images(predicted_image, expected_image)
|
ensure_similar_images(predicted_image, expected_image)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test_restart(
|
||||||
|
sd15_ddim: StableDiffusion_1,
|
||||||
|
expected_restart: Image.Image,
|
||||||
|
test_device: torch.device,
|
||||||
|
):
|
||||||
|
sd15 = sd15_ddim
|
||||||
|
n_steps = 30
|
||||||
|
|
||||||
|
prompt = "a cute cat, detailed high-quality professional image"
|
||||||
|
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||||
|
|
||||||
|
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||||
|
|
||||||
|
sd15.set_num_inference_steps(n_steps)
|
||||||
|
restart = Restart(ldm=sd15)
|
||||||
|
|
||||||
|
manual_seed(2)
|
||||||
|
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||||
|
|
||||||
|
for step in sd15.steps:
|
||||||
|
x = sd15(
|
||||||
|
x,
|
||||||
|
step=step,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
if step == restart.start_step:
|
||||||
|
x = restart(
|
||||||
|
x,
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
condition_scale=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
predicted_image = sd15.lda.decode_latents(x)
|
||||||
|
|
||||||
|
ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
|
@ -44,6 +44,7 @@ Special cases:
|
||||||
- `expected_t2i_adapter_xl_canny.png`
|
- `expected_t2i_adapter_xl_canny.png`
|
||||||
- `expected_image_sdxl_ip_adapter_plus_woman.png`
|
- `expected_image_sdxl_ip_adapter_plus_woman.png`
|
||||||
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
|
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
|
||||||
|
- `expected_restart.png`
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_restart.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_restart.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 504 KiB |
Loading…
Reference in a new issue