mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add torch.Generator to MultiUpscaler.upscale + make MultiUpscaler.diffuse_targets "stateless"
This commit is contained in:
parent
883a2121f2
commit
5a92285104
|
@ -8,7 +8,7 @@ from PIL import Image
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
|
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, no_grad
|
||||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
|
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget, MultiDiffusion, Size
|
||||||
|
@ -217,13 +217,12 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
|
||||||
|
|
||||||
def diffuse_targets(
|
def diffuse_targets(
|
||||||
self,
|
self,
|
||||||
|
noise: torch.Tensor,
|
||||||
targets: Sequence[T],
|
targets: Sequence[T],
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
latent_size: Size,
|
|
||||||
first_step: int,
|
first_step: int,
|
||||||
autoencoder_tile_length: int,
|
autoencoder_tile_length: int,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
noise = torch.randn(size=(1, 4, *latent_size), device=self.device, dtype=self.dtype)
|
|
||||||
with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
|
with self.sd.lda.tiled_inference(image, (autoencoder_tile_length, autoencoder_tile_length)):
|
||||||
latents = self.sd.lda.tiled_image_to_latents(image)
|
latents = self.sd.lda.tiled_image_to_latents(image)
|
||||||
x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
|
x = self.sd.solver.add_noise(x=latents, noise=noise, step=first_step)
|
||||||
|
@ -249,7 +248,7 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
|
||||||
solver_type: type[Solver] = DPMSolver,
|
solver_type: type[Solver] = DPMSolver,
|
||||||
num_inference_steps: int = 18,
|
num_inference_steps: int = 18,
|
||||||
autoencoder_tile_length: int = 1024,
|
autoencoder_tile_length: int = 1024,
|
||||||
seed: int = 37,
|
generator: torch.Generator | None = None,
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
"""
|
"""
|
||||||
Upscale an image using the multi upscaler.
|
Upscale an image using the multi upscaler.
|
||||||
|
@ -280,10 +279,8 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
|
||||||
between quality and speed.
|
between quality and speed.
|
||||||
autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
|
autoencoder_tile_length: The length of the autoencoder tiles. It shouldn't affect the end result, but
|
||||||
lowering it can reduce GPU memory usage (but increase computation time).
|
lowering it can reduce GPU memory usage (but increase computation time).
|
||||||
seed: The seed to use for the random number generator.
|
generator: The random number generator to use for sampling noise.
|
||||||
"""
|
"""
|
||||||
manual_seed(seed)
|
|
||||||
|
|
||||||
# update controlnet scale
|
# update controlnet scale
|
||||||
self.controlnet.scale = controlnet_scale
|
self.controlnet.scale = controlnet_scale
|
||||||
self.controlnet.scale_decay = controlnet_scale_decay
|
self.controlnet.scale_decay = controlnet_scale_decay
|
||||||
|
@ -323,11 +320,19 @@ class MultiUpscalerAbstract(MultiDiffusion[T], ABC):
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# initialize the noise
|
||||||
|
noise = torch.randn(
|
||||||
|
size=(1, 4, *latent_size),
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
generator=generator,
|
||||||
|
)
|
||||||
|
|
||||||
# diffuse the tiles
|
# diffuse the tiles
|
||||||
return self.diffuse_targets(
|
return self.diffuse_targets(
|
||||||
|
noise=noise,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
image=image,
|
image=image,
|
||||||
latent_size=latent_size,
|
|
||||||
first_step=first_step,
|
first_step=first_step,
|
||||||
autoencoder_tile_length=autoencoder_tile_length,
|
autoencoder_tile_length=autoencoder_tile_length,
|
||||||
)
|
)
|
||||||
|
|
|
@ -2669,7 +2669,9 @@ def test_multi_upscaler(
|
||||||
clarity_example: Image.Image,
|
clarity_example: Image.Image,
|
||||||
expected_multi_upscaler: Image.Image,
|
expected_multi_upscaler: Image.Image,
|
||||||
) -> None:
|
) -> None:
|
||||||
predicted_image = multi_upscaler.upscale(clarity_example)
|
generator = torch.Generator(device=multi_upscaler.device)
|
||||||
|
generator.manual_seed(37)
|
||||||
|
predicted_image = multi_upscaler.upscale(clarity_example, generator=generator)
|
||||||
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
|
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue