mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
add torch.Generator to MultiUpscaler.upscale + make MultiUpscaler.diffuse_targets "stateless"
This commit is contained in:
parent
883a2121f2
commit
5a92285104
|
@ -8,7 +8,7 @@ from PIL import Image
|
|||
from torch import Tensor
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
|
||||
|
@ -217,13 +217,12 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
|
|||
|
||||
def diffuse_targets(
|
||||
self,
|
||||
noise: torch.Tensor,
|
||||
targets: Sequence[T],
|
||||
image: Image.Image,
|
||||
latent_size: Size,
|
||||
first_step: int,
|
||||
autoencoder_tile_length: int,
|
||||
) -> Image.Image:
|
||||
noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
|
||||
with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
|
||||
latents = self.sd.lda.tiled_image_to_latents(image)
|
||||
x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
|
||||
|
@ -249,7 +248,7 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
|
|||
solver_type: type[Solver] = DPMSolver,
|
||||
num_inference_steps: int = 18,
|
||||
autoencoder_tile_length: int = 1024,
|
||||
seed: int = 37,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Upscale an image using the multi upscaler.
|
||||
|
@ -280,10 +279,8 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
|
|||
between quality and speed.
|
||||
autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
|
||||
lowering it can reduce GPU memory usage (but increase computation time).
|
||||
seed: The seed to use for the random number generator.
|
||||
generator: The random number generator to use for sampling noise.
|
||||
"""
|
||||
manual_seed(seed)
|
||||
|
||||
# update controlnet scale
|
||||
self.controlnet.scale = controlnet_scale
|
||||
self.controlnet.scale_decay = controlnet_scale_decay
|
||||
|
@ -323,11 +320,19 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
|
|||
clip_text_embedding=clip_text_embedding,
|
||||
)
|
||||
|
||||
# initialize the noise
|
||||
noise = torch.randn(
|
||||
size=(1, 4, *latent_size),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
# diffuse the tiles
|
||||
return self.diffuse_targets(
|
||||
noise=noise,
|
||||
targets=targets,
|
||||
image=image,
|
||||
latent_size=latent_size,
|
||||
first_step=first_step,
|
||||
autoencoder_tile_length=autoencoder_tile_length,
|
||||
)
|
||||
|
|
|
@ -2669,7 +2669,9 @@ def test_multi_upscaler(
|
|||
clarity_example: Image.Image,
|
||||
expected_multi_upscaler: Image.Image,
|
||||
) -> None:
|
||||
predicted_image = multi_upscaler.upscale(clarity_example)
|
||||
generator = torch.Generator(device=multi_upscaler.device)
|
||||
generator.manual_seed(37)
|
||||
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
|
||||
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue