From 5a922851047aef20d1b1b8903ec0b696ee8e0c92 Mon Sep 17 00:00:00 2001
From: Laurent
Date: Thu, 26 Sep 2024 09:12:25 +0000
Subject: [PATCH] add torch.Generator to MultiUpscaler.upscale + make
 MultiUpscaler.diffuse_targets "stateless"

---
 .../stable_diffusion_1/multi_upscaler.py      | 21 +++++++++++++--------
 tests/e2e/test_diffusion.py                   |  4 +++-
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_upscaler.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_upscaler.py
index 5f110c5..779f032 100644
--- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_upscaler.py
+++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/multi_upscaler.py
@@ -8,7 +8,7 @@ from PIL import Image
 from torch import Tensor
 from typing_extensions import TypeVar
 
-from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
+from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
 from refiners.foundationals.latent_diffusion.lora import SDLoraManager
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
@@ -217,13 +217,12 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
 
     def diffuse_targets(
         self,
+        noise: torch.Tensor,
         targets: Sequence[T],
         image: Image.Image,
-        latent_size: Size,
         first_step: int,
         autoencoder_tile_length: int,
     ) -> Image.Image:
-        noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
         with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
             latents = self.sd.lda.tiled_image_to_latents(image)
             x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
@@ -249,7 +248,7 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
         solver_type: type[Solver] = DPMSolver,
         num_inference_steps: int = 18,
         autoencoder_tile_length: int = 1024,
-        seed: int = 37,
+        generator: torch.Generator | None = None,
     ) -> Image.Image:
         """
         Upscale an image using the multi upscaler.
@@ -280,10 +279,8 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
                 between quality and speed.
             autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
                 lowering it can reduce GPU memory usage (but increase computation time).
-            seed: The seed to use for the random number generator.
+            generator: The random number generator to use for sampling noise.
         """
-        manual_seed(seed)
-
         # update controlnet scale
         self.controlnet.scale = controlnet_scale
         self.controlnet.scale_decay = controlnet_scale_decay
@@ -323,11 +320,19 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
                 clip_text_embedding=clip_text_embedding,
             )
 
+        # initialize the noise
+        noise = torch.randn(
+            size=(1, 4, *latent_size),
+            device=self.device,
+            dtype=self.dtype,
+            generator=generator,
+        )
+
         # diffuse the tiles
         return self.diffuse_targets(
+            noise=noise,
             targets=targets,
             image=image,
-            latent_size=latent_size,
             first_step=first_step,
             autoencoder_tile_length=autoencoder_tile_length,
         )
diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py
index dbcd16e..3315520 100644
--- a/tests/e2e/test_diffusion.py
+++ b/tests/e2e/test_diffusion.py
@@ -2669,7 +2669,9 @@ def test_multi_upscaler(
     clarity_example: Image.Image,
     expected_multi_upscaler: Image.Image,
 ) -> None:
-    predicted_image = multi_upscaler.upscale(clarity_example)
+    generator = torch.Generator(device=multi_upscaler.device)
+    generator.manual_seed(37)
+    predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
     ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)