diff --git a/src/refiners/foundationals/latent_diffusion/restart.py b/src/refiners/foundationals/latent_diffusion/restart.py
new file mode 100644
index 0000000..c436939
--- /dev/null
+++ b/src/refiners/foundationals/latent_diffusion/restart.py
@@ -0,0 +1,110 @@
+from dataclasses import dataclass
+from functools import cached_property
+from typing import Generic, TypeVar
+
+import torch
+
+from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
+from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
+from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
+
+T = TypeVar("T", bound=LatentDiffusionModel)
+
+
+def add_noise_interval(
+    scheduler: Scheduler,
+    /,
+    x: torch.Tensor,
+    noise: torch.Tensor,
+    initial_timestep: torch.Tensor,
+    target_timestep: torch.Tensor,
+) -> torch.Tensor:
+    initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
+    target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
+
+    factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
+    noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
+    return noised_x
+
+
+@dataclass
+class Restart(Generic[T]):
+    """
+    Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
+    (https://arxiv.org/pdf/2306.14878.pdf)
+
+    Works only with the DDIM scheduler for now.
+    """
+
+    ldm: T
+    num_steps: int = 10
+    num_iterations: int = 2
+    start_time: float = 0.1
+    end_time: float = 2
+
+    def __post_init__(self) -> None:
+        assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
+
+    def __call__(
+        self,
+        x: torch.Tensor,
+        /,
+        clip_text_embedding: torch.Tensor,
+        condition_scale: float = 7.5,
+        **kwargs: torch.Tensor,
+    ) -> torch.Tensor:
+        original_scheduler = self.ldm.scheduler
+        new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
+        new_scheduler.timesteps = self.timesteps
+        self.ldm.scheduler = new_scheduler
+
+        for _ in range(self.num_iterations):
+            noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
+            x = add_noise_interval(
+                new_scheduler,
+                x=x,
+                noise=noise,
+                initial_timestep=self.timesteps[-1],
+                target_timestep=self.timesteps[0],
+            )
+
+            for step in range(len(self.timesteps) - 1):
+                x = self.ldm(
+                    x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
+                )
+
+        self.ldm.scheduler = original_scheduler
+
+        return x
+
+    @cached_property
+    def start_step(self) -> int:
+        sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
+        return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
+
+    @cached_property
+    def end_timestep(self) -> int:
+        sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
+        return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))
+
+    @cached_property
+    def timesteps(self) -> torch.Tensor:
+        return (
+            torch.round(
+                torch.linspace(
+                    start=int(self.ldm.scheduler.timesteps[self.start_step]),
+                    end=self.end_timestep,
+                    steps=self.num_steps,
+                )
+            )
+            .flip(0)
+            .to(device=self.device, dtype=torch.int64)
+        )
+
+    @property
+    def device(self) -> torch.device:
+        return self.ldm.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.ldm.dtype
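The `add_noise_interval` helper re-noises a latent from a lower noise level (`initial_timestep`) to a higher one (`target_timestep`) in closed form. For a variance-preserving schedule where `x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε` and `cumulative_scale_factors[t] = sqrt(ᾱ_t)`, jumping from level s to level t only needs the ratio `factor = sqrt(ᾱ_t) / sqrt(ᾱ_s)`. A minimal standalone check of that algebra, assuming a toy ᾱ schedule (illustration only, not refiners code):

```python
import torch

torch.manual_seed(0)

# Toy variance-preserving schedule (assumed for illustration only).
alphas_cumprod = torch.linspace(0.9999, 0.05, 1000)
cumulative_scale_factors = alphas_cumprod.sqrt()  # sqrt(ᾱ_t)

s, t = 100, 600  # initial (less noisy) and target (more noisy) timesteps
x0 = torch.randn(100_000)

# Noise x0 directly to level s: x_s = sqrt(ᾱ_s) x0 + sqrt(1 - ᾱ_s) ε
c_s = cumulative_scale_factors[s]
x_s = c_s * x0 + torch.sqrt(1 - c_s**2) * torch.randn_like(x0)

# Re-noise from s to t with the interval formula used by add_noise_interval.
c_t = cumulative_scale_factors[t]
factor = c_t / c_s
x_t = factor * x_s + torch.sqrt(1 - factor**2) * torch.randn_like(x_s)

# The signal coefficient of x_t should match a direct jump to level t,
# i.e. E[x_t * x0] ≈ sqrt(ᾱ_t).
print(float((x_t * x0).mean()), float(c_t))
```

The two printed values should agree up to sampling noise, confirming that re-noising over an interval composes with the direct forward process; this is what lets each restart iteration jump straight back to the high-noise end of the interval.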
diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
index 8dd2e25..f873d53 100644
--- a/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
+++ b/src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
@@ -1,4 +1,4 @@
-from torch import Tensor, device as Device, arange, sqrt
+from torch import Tensor, device as Device, dtype as Dtype, arange, sqrt, float32, int64, tensor
 
 from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
 
@@ -10,8 +10,16 @@ class DDIM(Scheduler):
         initial_diffusion_rate: float = 8.5e-4,
         final_diffusion_rate: float = 1.2e-2,
         device: Device | str = "cpu",
+        dtype: Dtype = float32,
     ) -> None:
-        super().__init__(num_inference_steps, num_train_timesteps, initial_diffusion_rate, final_diffusion_rate, device)
+        super().__init__(
+            num_inference_steps,
+            num_train_timesteps,
+            initial_diffusion_rate,
+            final_diffusion_rate,
+            device=device,
+            dtype=dtype,
+        )
         self.timesteps = self._generate_timesteps()
 
     def _generate_timesteps(self) -> Tensor:
@@ -26,7 +34,11 @@ class DDIM(Scheduler):
     def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
         timestep, previous_timestep = (
             self.timesteps[step],
-            self.timesteps[step] - self.num_train_timesteps // self.num_inference_steps,
+            (
+                self.timesteps[step + 1]
+                if step < self.num_inference_steps - 1
+                else tensor(data=[0], device=self.device, dtype=int64)
+            ),
         )
         current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
             self.cumulative_scale_factors[previous_timestep]
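In the DDIM step above, the previous timestep is now read directly from the `timesteps` tensor instead of being derived as `timesteps[step] - num_train_timesteps // num_inference_steps`. Fixed-stride subtraction is only valid on an evenly spaced grid; once `Restart` installs its own non-uniform grid, `timesteps[step + 1]` is the only lookup that stays correct, and the final step falls back to an explicit integer timestep-0 tensor instead of going negative. A toy comparison (assumed grids, not the actual refiners schedule):

```python
import torch

num_train_timesteps, num_inference_steps = 1000, 5
stride = num_train_timesteps // num_inference_steps  # 200

uniform = torch.tensor([801, 601, 401, 201, 1])  # evenly spaced: stride works
custom = torch.tensor([500, 420, 250, 90, 0])    # non-uniform, e.g. a restart grid

for grid in (uniform, custom):
    for step in range(len(grid) - 1):
        by_stride = int(grid[step]) - stride  # old derivation
        by_lookup = int(grid[step + 1])       # new lookup
        print(f"step {step}: stride={by_stride} lookup={by_lookup} match={by_stride == by_lookup}")
```

On the uniform grid both derivations agree at every step; on the non-uniform grid the stride-based value (e.g. 300 at step 0) misses the actual next entry (420), which is why the lookup form is required for restart sampling.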
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index c5e8473..fca8dfa 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -20,6 +20,7 @@ from refiners.foundationals.latent_diffusion import (
 )
 from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
+from refiners.foundationals.latent_diffusion.restart import Restart
 from refiners.foundationals.latent_diffusion.schedulers import DDIM
 from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
 from refiners.foundationals.clip.concepts import ConceptExtender
@@ -221,6 +222,11 @@ def expected_multi_diffusion(ref_path: Path) -> Image.Image:
     return Image.open(fp=ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
 
 
+@pytest.fixture
+def expected_restart(ref_path: Path) -> Image.Image:
+    return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB")
+
+
 @pytest.fixture
 def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
     return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")[""]  # type: ignore
@@ -1558,3 +1564,43 @@ def test_t2i_adapter_xl_canny(
     predicted_image = sdxl.lda.decode_latents(x)
 
     ensure_similar_images(predicted_image, expected_image)
+
+
+@torch.no_grad()
+def test_restart(
+    sd15_ddim: StableDiffusion_1,
+    expected_restart: Image.Image,
+    test_device: torch.device,
+):
+    sd15 = sd15_ddim
+    n_steps = 30
+
+    prompt = "a cute cat, detailed high-quality professional image"
+    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
+
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    sd15.set_num_inference_steps(n_steps)
+    restart = Restart(ldm=sd15)
+
+    manual_seed(2)
+    x = torch.randn(1, 4, 64, 64, device=test_device)
+
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=8,
+        )
+
+        if step == restart.start_step:
+            x = restart(
+                x,
+                clip_text_embedding=clip_text_embedding,
+                condition_scale=8,
+            )
+
+    predicted_image = sd15.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
diff --git a/tests/e2e/test_diffusion_ref/README.md b/tests/e2e/test_diffusion_ref/README.md
index f1feae8..62cf504 100644
--- a/tests/e2e/test_diffusion_ref/README.md
+++ b/tests/e2e/test_diffusion_ref/README.md
@@ -44,6 +44,7 @@ Special cases:
   - `expected_t2i_adapter_xl_canny.png`
   - `expected_image_sdxl_ip_adapter_plus_woman.png`
   - `expected_cutecat_sdxl_ddim_random_init_sag.png`
+  - `expected_restart.png`
 
 ## Other images
 
diff --git a/tests/e2e/test_diffusion_ref/expected_restart.png b/tests/e2e/test_diffusion_ref/expected_restart.png
new file mode 100644
index 0000000..bf19885
Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_restart.png differ
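For reference, `start_step` and `end_timestep` translate the restart interval from the paper, which is expressed in noise levels σ_t = sqrt(1 − ᾱ_t) / sqrt(ᾱ_t), into grid positions: `start_step` is the inference-loop step whose σ is closest to `start_time` (hence the `if step == restart.start_step` hook in the test), while `end_timestep` is the train timestep whose σ is closest to `end_time`. A standalone sketch of that selection, assuming a toy ᾱ schedule and a 30-step inference grid (assumed values, not the refiners scheduler):

```python
import torch

# Assumed toy schedule; in refiners, sigmas = noise_std / cumulative_scale_factors.
alphas_cumprod = torch.linspace(0.9999, 0.02, 1000)
cumulative_scale_factors = alphas_cumprod.sqrt()  # sqrt(ᾱ_t)
noise_std = (1 - alphas_cumprod).sqrt()           # sqrt(1 - ᾱ_t)
sigmas = noise_std / cumulative_scale_factors

inference_timesteps = torch.linspace(999, 0, 30).long()  # assumed DDIM grid

start_time, end_time = 0.1, 2.0
# start_step: position in the inference grid where sigma is closest to start_time.
start_step = int(torch.argmin(torch.abs(sigmas[inference_timesteps] - start_time)))
# end_timestep: train timestep where sigma is closest to end_time.
end_timestep = int(torch.argmin(torch.abs(sigmas - end_time)))

# The restart grid then spans these two noise levels with num_steps entries,
# flipped so the denoising loop walks from high noise down to low noise.
restart_timesteps = (
    torch.round(torch.linspace(int(inference_timesteps[start_step]), end_timestep, steps=10))
    .flip(0)
    .long()
)
print(start_step, end_timestep, restart_timesteps)
```

This mirrors why the test triggers `restart(...)` only once, at the step whose noise level matches `start_time`: from there the `Restart` object re-noises up to `end_time` and denoises back down on its own dedicated timestep grid.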