Add a torch.Generator parameter to MultiUpscaler.upscale and make MultiUpscaler.diffuse_targets "stateless"
Some checks failed
CI / lint_and_typecheck (push) Has been cancelled
Deploy docs to GitHub Pages / Deploy docs (push) Has been cancelled
Spell checker / Spell check (push) Has been cancelled

This commit is contained in:
Laurent 2024-09-26 09:12:25 +00:00 committed by Laureηt
parent 883a2121f2
commit 5a92285104
2 changed files with 16 additions and 9 deletions

View file

@ -8,7 +8,7 @@ from PIL import Image
from torch import Tensor from torch import Tensor
from typing_extensions import TypeVar from typing_extensions import TypeVar
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.lora import SDLoraManager from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
@ -217,13 +217,12 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
def diffuse_targets( def diffuse_targets(
self, self,
noise: torch.Tensor,
targets: Sequence[T], targets: Sequence[T],
image: Image.Image, image: Image.Image,
latent_size: Size,
first_step: int, first_step: int,
autoencoder_tile_length: int, autoencoder_tile_length: int,
) -> Image.Image: ) -> Image.Image:
noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)): with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
latents = self.sd.lda.tiled_image_to_latents(image) latents = self.sd.lda.tiled_image_to_latents(image)
x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step) x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
@ -249,7 +248,7 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
solver_type: type[Solver] = DPMSolver, solver_type: type[Solver] = DPMSolver,
num_inference_steps: int = 18, num_inference_steps: int = 18,
autoencoder_tile_length: int = 1024, autoencoder_tile_length: int = 1024,
seed: int = 37, generator: torch.Generator | None = None,
) -> Image.Image: ) -> Image.Image:
""" """
Upscale an image using the multi upscaler. Upscale an image using the multi upscaler.
@ -280,10 +279,8 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
between quality and speed. between quality and speed.
autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
lowering it can reduce GPU memory usage (but increase computation time). lowering it can reduce GPU memory usage (but increase computation time).
seed: The seed to use for the random number generator. generator: The random number generator to use for sampling noise.
""" """
manual_seed(seed)
# update controlnet scale # update controlnet scale
self.controlnet.scale = controlnet_scale self.controlnet.scale = controlnet_scale
self.controlnet.scale_decay = controlnet_scale_decay self.controlnet.scale_decay = controlnet_scale_decay
@ -323,11 +320,19 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
clip_text_embedding=clip_text_embedding, clip_text_embedding=clip_text_embedding,
) )
# initialize the noise
noise = torch.randn(
size=(1, 4, *latent_size),
device=self.device,
dtype=self.dtype,
generator=generator,
)
# diffuse the tiles # diffuse the tiles
return self.diffuse_targets( return self.diffuse_targets(
noise=noise,
targets=targets, targets=targets,
image=image, image=image,
latent_size=latent_size,
first_step=first_step, first_step=first_step,
autoencoder_tile_length=autoencoder_tile_length, autoencoder_tile_length=autoencoder_tile_length,
) )

View file

@ -2669,7 +2669,9 @@ def test_multi_upscaler(
clarity_example: Image.Image, clarity_example: Image.Image,
expected_multi_upscaler: Image.Image, expected_multi_upscaler: Image.Image,
) -> None: ) -> None:
predicted_image = multi_upscaler.upscale(clarity_example) generator = torch.Generator(device=multi_upscaler.device)
generator.manual_seed(37)
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99) ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)