mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
implement Restart method for latent diffusion
This commit is contained in:
parent
e35dce825f
commit
7a62049d54
110
src/refiners/foundationals/latent_diffusion/restart.py
Normal file
110
src/refiners/foundationals/latent_diffusion/restart.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
|
||||
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
|
||||
T = TypeVar("T", bound=LatentDiffusionModel)
|
||||
|
||||
|
||||
def add_noise_interval(
|
||||
scheduler: Scheduler,
|
||||
/,
|
||||
x: torch.Tensor,
|
||||
noise: torch.Tensor,
|
||||
initial_timestep: torch.Tensor,
|
||||
target_timestep: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
|
||||
target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
|
||||
|
||||
factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
|
||||
noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
|
||||
return noised_x
|
||||
|
||||
|
||||
@dataclass
|
||||
class Restart(Generic[T]):
|
||||
"""
|
||||
Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
|
||||
(https://arxiv.org/pdf/2306.14878.pdf)
|
||||
|
||||
Works only with the DDIM scheduler for now.
|
||||
"""
|
||||
|
||||
ldm: T
|
||||
num_steps: int = 10
|
||||
num_iterations: int = 2
|
||||
start_time: float = 0.1
|
||||
end_time: float = 2
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
/,
|
||||
clip_text_embedding: torch.Tensor,
|
||||
condition_scale: float = 7.5,
|
||||
**kwargs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
original_scheduler = self.ldm.scheduler
|
||||
new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
|
||||
new_scheduler.timesteps = self.timesteps
|
||||
self.ldm.scheduler = new_scheduler
|
||||
|
||||
for _ in range(self.num_iterations):
|
||||
noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
|
||||
x = add_noise_interval(
|
||||
new_scheduler,
|
||||
x=x,
|
||||
noise=noise,
|
||||
initial_timestep=self.timesteps[-1],
|
||||
target_timestep=self.timesteps[0],
|
||||
)
|
||||
|
||||
for step in range(len(self.timesteps) - 1):
|
||||
x = self.ldm(
|
||||
x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
|
||||
)
|
||||
|
||||
self.ldm.scheduler = original_scheduler
|
||||
|
||||
return x
|
||||
|
||||
@cached_property
|
||||
def start_step(self) -> int:
|
||||
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
|
||||
return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
|
||||
|
||||
@cached_property
|
||||
def end_timestep(self) -> int:
|
||||
sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
|
||||
return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))
|
||||
|
||||
@cached_property
|
||||
def timesteps(self) -> torch.Tensor:
|
||||
return (
|
||||
torch.round(
|
||||
torch.linspace(
|
||||
start=int(self.ldm.scheduler.timesteps[self.start_step]),
|
||||
end=self.end_timestep,
|
||||
steps=self.num_steps,
|
||||
)
|
||||
)
|
||||
.flip(0)
|
||||
.to(device=self.device, dtype=torch.int64)
|
||||
)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self.ldm.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.ldm.dtype
|
|
@ -1,4 +1,4 @@
|
|||
from torch import Tensor, device as Device, arange, sqrt
|
||||
from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, tensor
|
||||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
|
||||
|
||||
|
||||
|
@ -10,8 +10,16 @@ class DDIM(Scheduler):
|
|||
initial_diffusion_rate: float = 8.5e-4,
|
||||
final_diffusion_rate: float = 1.2e-2,
|
||||
device: Device | str = "cpu",
|
||||
dtype: Dtype = float32,
|
||||
) -> None:
|
||||
super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
|
||||
super().__init__(
|
||||
num_inference_steps,
|
||||
num_train_timesteps,
|
||||
initial_diffusion_rate,
|
||||
final_diffusion_rate,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.timesteps = self._generate_timesteps()
|
||||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
|
@ -26,7 +34,11 @@ class DDIM(Scheduler):
|
|||
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
timestep, previous_timestep = (
|
||||
self.timesteps[step],
|
||||
self.timesteps[step] - self.num_train_timesteps // self.num_inference_steps,
|
||||
(
|
||||
self.timesteps[step + 1]
|
||||
if step < self.num_inference_steps - 1
|
||||
else tensor(data=[0], device=self.device, dtype=self.dtype)
|
||||
),
|
||||
)
|
||||
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
|
||||
self.cumulative_scale_factors[previous_timestep]
|
||||
|
|
|
@ -20,6 +20,7 @@ from refiners.foundationals.latent_diffusion import (
|
|||
)
|
||||
from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
|
||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||
from refiners.foundationals.latent_diffusion.schedulers import DDIM
|
||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
|
@ -221,6 +222,11 @@ def expected_multi_diffusion(ref_path: Path) -> Image.Image:
|
|||
return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_restart(ref_path: Path) -> Image.Image:
|
||||
return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
|
||||
return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"] # type: ignore
|
||||
|
@ -1558,3 +1564,43 @@ def test_t2i_adapter_xl_canny(
|
|||
predicted_image = sdxl.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_restart(
|
||||
sd15_ddim: StableDiffusion_1,
|
||||
expected_restart: Image.Image,
|
||||
test_device: torch.device,
|
||||
):
|
||||
sd15 = sd15_ddim
|
||||
n_steps = 30
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_num_inference_steps(n_steps)
|
||||
restart = Restart(ldm=sd15)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=test_device)
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=8,
|
||||
)
|
||||
|
||||
if step == restart.start_step:
|
||||
x = restart(
|
||||
x,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=8,
|
||||
)
|
||||
|
||||
predicted_image = sd15.lda.decode_latents(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
|
||||
|
|
|
@ -44,6 +44,7 @@ Special cases:
|
|||
- `expected_t2i_adapter_xl_canny.png`
|
||||
- `expected_image_sdxl_ip_adapter_plus_woman.png`
|
||||
- `expected_cutecat_sdxl_ddim_random_init_sag.png`
|
||||
- `expected_restart.png`
|
||||
|
||||
## Other images
|
||||
|
||||
|
|
BIN
tests/e2e/test_diffusion_ref/expected_restart.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_restart.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 504 KiB |
Loading…
Reference in a new issue