add torch.Generator to MultiUpscaler.upscale + make MultiUpscaler.diffuse_targets "stateless"
Some checks failed
CI / lint_and_typecheck (push) Has been cancelled
Deploy docs to GitHub Pages / Deploy docs (push) Has been cancelled
Spell checker / Spell check (push) Has been cancelled

This commit is contained in:
Laurent 2024-09-26 09:12:25 +00:00 committed by Laureηt
parent 883a2121f2
commit 5a92285104
2 changed files with 16 additions and 9 deletions

View file

@ -8,7 +8,7 @@ from PIL import Image
from torch import Tensor
from typing_extensions import TypeVar
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
@ -217,13 +217,12 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
def diffuse_targets(
self,
noise: torch.Tensor,
targets: Sequence[T],
image: Image.Image,
latent_size: Size,
first_step: int,
autoencoder_tile_length: int,
) -> Image.Image:
noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
latents = self.sd.lda.tiled_image_to_latents(image)
x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
@ -249,7 +248,7 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
solver_type: type[Solver] = DPMSolver,
num_inference_steps: int = 18,
autoencoder_tile_length: int = 1024,
seed: int = 37,
generator: torch.Generator | None = None,
) -> Image.Image:
"""
Upscale an image using the multi upscaler.
@ -280,10 +279,8 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
between quality and speed.
autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
lowering it can reduce GPU memory usage (but increase computation time).
seed: The seed to use for the random number generator.
generator: The random number generator to use for sampling noise.
"""
manual_seed(seed)
# update controlnet scale
self.controlnet.scale = controlnet_scale
self.controlnet.scale_decay = controlnet_scale_decay
@ -323,11 +320,19 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
clip_text_embedding=clip_text_embedding,
)
# initialize the noise
noise = torch.randn(
size=(1, 4, *latent_size),
device=self.device,
dtype=self.dtype,
generator=generator,
)
# diffuse the tiles
return self.diffuse_targets(
noise=noise,
targets=targets,
image=image,
latent_size=latent_size,
first_step=first_step,
autoencoder_tile_length=autoencoder_tile_length,
)

View file

@ -2669,7 +2669,9 @@ def test_multi_upscaler(
clarity_example: Image.Image,
expected_multi_upscaler: Image.Image,
) -> None:
predicted_image = multi_upscaler.upscale(clarity_example)
generator = torch.Generator(device=multi_upscaler.device)
generator.manual_seed(37)
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)