improve/add MultiDiffusion and MultiUpscaler e2e tests

Co-authored-by: limiteinductive <benjamin@lagon.tech>
Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
Laurent 2024-07-11 13:05:15 +00:00 committed by Laureηt
parent f3b5c8d3e1
commit f44ae150a7
6 changed files with 134 additions and 9 deletions

View file

@ -25,11 +25,19 @@ from refiners.foundationals.latent_diffusion import (
StableDiffusion_1_Inpainting, StableDiffusion_1_Inpainting,
) )
from refiners.foundationals.latent_diffusion.lora import SDLoraManager from refiners.foundationals.latent_diffusion.lora import SDLoraManager
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget from refiners.foundationals.latent_diffusion.multi_diffusion import Size, Tile
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import (
SD1DiffusionTarget,
SD1MultiDiffusion,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
MultiUpscaler,
UpscalerCheckpoints,
)
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
@ -406,6 +414,11 @@ def expected_multi_diffusion(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB") return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
@pytest.fixture
def expected_multi_diffusion_dpm(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB")
@pytest.fixture @pytest.fixture
def expected_restart(ref_path: Path) -> Image.Image: def expected_restart(ref_path: Path) -> Image.Image:
return _img_open(ref_path / "expected_restart.png").convert(mode="RGB") return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
@ -756,6 +769,41 @@ def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_X
) )
@pytest.fixture(scope="module")
def multi_upscaler(
test_weights_path: Path,
unet_weights_std: Path,
text_encoder_weights: Path,
lda_ft_mse_weights: Path,
test_device: torch.device,
) -> MultiUpscaler:
controlnet_tile_weights = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors"
if not controlnet_tile_weights.is_file():
warn(message=f"could not find weights at {controlnet_tile_weights}, skipping")
pytest.skip(allow_module_level=True)
return MultiUpscaler(
checkpoints=UpscalerCheckpoints(
unet=unet_weights_std,
clip_text_encoder=text_encoder_weights,
lda=lda_ft_mse_weights,
controlnet_tile=controlnet_tile_weights,
),
device=test_device,
dtype=torch.float32,
)
@pytest.fixture(scope="module")
def clarity_example(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "clarity_input_example.png")
@pytest.fixture(scope="module")
def expected_multi_upscaler(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_multi_upscaler.png")
@no_grad() @no_grad()
def test_diffusion_std_random_init( def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
@ -2132,15 +2180,15 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
sd = sd15_ddim sd = sd15_ddim
multi_diffusion = SD1MultiDiffusion(sd) multi_diffusion = SD1MultiDiffusion(sd)
clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain") clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")
target_1 = DiffusionTarget( # DDIM doesn't have an internal state, so we can share the same solver for all targets
size=(64, 64), target_1 = SD1DiffusionTarget(
offset=(0, 0), tile=Tile(top=0, left=0, bottom=64, right=64),
solver=sd.solver,
clip_text_embedding=clip_text_embedding, clip_text_embedding=clip_text_embedding,
start_step=0,
) )
target_2 = DiffusionTarget( target_2 = SD1DiffusionTarget(
size=(64, 64), solver=sd.solver,
offset=(0, 16), tile=Tile(top=0, left=16, bottom=64, right=80),
clip_text_embedding=clip_text_embedding, clip_text_embedding=clip_text_embedding,
condition_scale=3, condition_scale=3,
start_step=0, start_step=0,
@ -2158,6 +2206,35 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98) ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
@no_grad()
def test_multi_diffusion_dpm(sd15_std: StableDiffusion_1, expected_multi_diffusion_dpm: Image.Image) -> None:
manual_seed(seed=2)
sd = sd15_std
multi_diffusion = SD1MultiDiffusion(sd)
clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")
tiles = SD1MultiDiffusion.generate_latent_tiles(size=Size(112, 196), tile_size=Size(96, 64), min_overlap=12)
targets = [
SD1DiffusionTarget(
tile=tile,
solver=DPMSolver(num_inference_steps=sd.solver.num_inference_steps, device=sd.device),
clip_text_embedding=clip_text_embedding,
)
for tile in tiles
]
noise = torch.randn(1, 4, 112, 196, device=sd.device, dtype=sd.dtype)
x = noise
for step in sd.steps:
x = multi_diffusion(
x,
noise=noise,
step=step,
targets=targets,
)
result = sd.lda.latents_to_image(x=x)
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion_dpm, min_psnr=35, min_ssim=0.98)
@no_grad() @no_grad()
def test_t2i_adapter_depth( def test_t2i_adapter_depth(
sd15_std: StableDiffusion_1, sd15_std: StableDiffusion_1,
@ -2427,3 +2504,13 @@ def test_style_aligned(
# compare against reference image # compare against reference image
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99) ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
@no_grad()
def test_multi_upscaler(
multi_upscaler: MultiUpscaler,
clarity_example: Image.Image,
expected_multi_upscaler: Image.Image,
) -> None:
predicted_image = multi_upscaler.upscale(clarity_example)
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)

View file

@ -58,6 +58,8 @@ Special cases:
- `expected_controllora_disabled.png` - `expected_controllora_disabled.png`
- `expected_style_aligned.png` - `expected_style_aligned.png`
- `expected_controlnet_canny_scale_decay.png` - `expected_controlnet_canny_scale_decay.png`
- `expected_multi_diffusion_dpm.png`
- `expected_multi_upscaler.png`
## Other images ## Other images
@ -92,6 +94,8 @@ Special cases:
- `low_res_dog.png` and `expected_controlnet_tile.png` are taken from Diffusers [documentation](https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/tree/main/images), respectively named - `low_res_dog.png` and `expected_controlnet_tile.png` are taken from Diffusers [documentation](https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/tree/main/images), respectively named
`original.png` and `output.png`. `original.png` and `output.png`.
- `clarity_input_example.png` is taken from the [Replicate demo](https://replicate.com/philz1337x/clarity-upscaler/examples) of the Clarity upscaler.
## VAE without randomness ## VAE without randomness
```diff ```diff

Binary file not shown.

After

Width:  |  Height:  |  Size: 949 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 MiB

View file

@ -0,0 +1,34 @@
import pytest
from refiners.foundationals.latent_diffusion.multi_diffusion import MultiDiffusion, Size
def test_generate_latent_tiles() -> None:
size = Size(height=128, width=128)
tile_size = Size(height=32, width=32)
tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size)
assert len(tiles) == 25
tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=0)
assert len(tiles) == 16
size = Size(height=100, width=200)
tile_size = Size(height=32, width=32)
tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=2)
assert len(tiles) == 28
def test_generate_latent_tiles_small_size() -> None:
# Test when the size is smaller than the tile size
size = Size(height=32, width=32)
tile_size = Size(height=64, width=64)
tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size)
assert len(tiles) == 1
assert Size(tiles[0].bottom - tiles[0].top, tiles[0].right - tiles[0].left) == size
def test_overlap_larger_tile_size() -> None:
with pytest.raises(AssertionError):
size = Size(height=128, width=128)
tile_size = Size(height=32, width=32)
MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=32)