mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
improve/add MultiDiffusion and MultiUpscaler e2e tests
Co-authored-by: limiteinductive <benjamin@lagon.tech> Co-authored-by: Cédric Deltheil <355031+deltheil@users.noreply.github.com>
This commit is contained in:
parent
f3b5c8d3e1
commit
f44ae150a7
|
@ -25,11 +25,19 @@ from refiners.foundationals.latent_diffusion import (
|
||||||
StableDiffusion_1_Inpainting,
|
StableDiffusion_1_Inpainting,
|
||||||
)
|
)
|
||||||
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
from refiners.foundationals.latent_diffusion.lora import SDLoraManager
|
||||||
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
|
from refiners.foundationals.latent_diffusion.multi_diffusion import Size, Tile
|
||||||
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
|
||||||
from refiners.foundationals.latent_diffusion.restart import Restart
|
from refiners.foundationals.latent_diffusion.restart import Restart
|
||||||
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
|
from refiners.foundationals.latent_diffusion.solvers import DDIM, Euler, NoiseSchedule, SolverParams
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
|
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import (
|
||||||
|
SD1DiffusionTarget,
|
||||||
|
SD1MultiDiffusion,
|
||||||
|
)
|
||||||
|
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
|
||||||
|
MultiUpscaler,
|
||||||
|
UpscalerCheckpoints,
|
||||||
|
)
|
||||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
|
||||||
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
|
from refiners.foundationals.latent_diffusion.style_aligned import StyleAlignedAdapter
|
||||||
|
|
||||||
|
@ -406,6 +414,11 @@ def expected_multi_diffusion(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
return _img_open(ref_path / "expected_multi_diffusion.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_multi_diffusion_dpm(ref_path: Path) -> Image.Image:
|
||||||
|
return _img_open(ref_path / "expected_multi_diffusion_dpm.png").convert(mode="RGB")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def expected_restart(ref_path: Path) -> Image.Image:
|
def expected_restart(ref_path: Path) -> Image.Image:
|
||||||
return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
|
return _img_open(ref_path / "expected_restart.png").convert(mode="RGB")
|
||||||
|
@ -756,6 +769,41 @@ def sdxl_euler_deterministic(sdxl_ddim: StableDiffusion_XL) -> StableDiffusion_X
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def multi_upscaler(
|
||||||
|
test_weights_path: Path,
|
||||||
|
unet_weights_std: Path,
|
||||||
|
text_encoder_weights: Path,
|
||||||
|
lda_ft_mse_weights: Path,
|
||||||
|
test_device: torch.device,
|
||||||
|
) -> MultiUpscaler:
|
||||||
|
controlnet_tile_weights = test_weights_path / "controlnet" / "lllyasviel_control_v11f1e_sd15_tile.safetensors"
|
||||||
|
if not controlnet_tile_weights.is_file():
|
||||||
|
warn(message=f"could not find weights at {controlnet_tile_weights}, skipping")
|
||||||
|
pytest.skip(allow_module_level=True)
|
||||||
|
|
||||||
|
return MultiUpscaler(
|
||||||
|
checkpoints=UpscalerCheckpoints(
|
||||||
|
unet=unet_weights_std,
|
||||||
|
clip_text_encoder=text_encoder_weights,
|
||||||
|
lda=lda_ft_mse_weights,
|
||||||
|
controlnet_tile=controlnet_tile_weights,
|
||||||
|
),
|
||||||
|
device=test_device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def clarity_example(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "clarity_input_example.png")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def expected_multi_upscaler(ref_path: Path) -> Image.Image:
|
||||||
|
return Image.open(ref_path / "expected_multi_upscaler.png")
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_diffusion_std_random_init(
|
def test_diffusion_std_random_init(
|
||||||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||||
|
@ -2132,15 +2180,15 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
|
||||||
sd = sd15_ddim
|
sd = sd15_ddim
|
||||||
multi_diffusion = SD1MultiDiffusion(sd)
|
multi_diffusion = SD1MultiDiffusion(sd)
|
||||||
clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")
|
clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")
|
||||||
target_1 = DiffusionTarget(
|
# DDIM doesn't have an internal state, so we can share the same solver for all targets
|
||||||
size=(64, 64),
|
target_1 = SD1DiffusionTarget(
|
||||||
offset=(0, 0),
|
tile=Tile(top=0, left=0, bottom=64, right=64),
|
||||||
|
solver=sd.solver,
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
start_step=0,
|
|
||||||
)
|
)
|
||||||
target_2 = DiffusionTarget(
|
target_2 = SD1DiffusionTarget(
|
||||||
size=(64, 64),
|
solver=sd.solver,
|
||||||
offset=(0, 16),
|
tile=Tile(top=0, left=16, bottom=64, right=80),
|
||||||
clip_text_embedding=clip_text_embedding,
|
clip_text_embedding=clip_text_embedding,
|
||||||
condition_scale=3,
|
condition_scale=3,
|
||||||
start_step=0,
|
start_step=0,
|
||||||
|
@ -2158,6 +2206,35 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
|
||||||
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
|
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_multi_diffusion_dpm(sd15_std: StableDiffusion_1, expected_multi_diffusion_dpm: Image.Image) -> None:
|
||||||
|
manual_seed(seed=2)
|
||||||
|
sd = sd15_std
|
||||||
|
multi_diffusion = SD1MultiDiffusion(sd)
|
||||||
|
clip_text_embedding = sd.compute_clip_text_embedding(text="a panorama of a mountain")
|
||||||
|
tiles = SD1MultiDiffusion.generate_latent_tiles(size=Size(112, 196), tile_size=Size(96, 64), min_overlap=12)
|
||||||
|
targets = [
|
||||||
|
SD1DiffusionTarget(
|
||||||
|
tile=tile,
|
||||||
|
solver=DPMSolver(num_inference_steps=sd.solver.num_inference_steps, device=sd.device),
|
||||||
|
clip_text_embedding=clip_text_embedding,
|
||||||
|
)
|
||||||
|
for tile in tiles
|
||||||
|
]
|
||||||
|
|
||||||
|
noise = torch.randn(1, 4, 112, 196, device=sd.device, dtype=sd.dtype)
|
||||||
|
x = noise
|
||||||
|
for step in sd.steps:
|
||||||
|
x = multi_diffusion(
|
||||||
|
x,
|
||||||
|
noise=noise,
|
||||||
|
step=step,
|
||||||
|
targets=targets,
|
||||||
|
)
|
||||||
|
result = sd.lda.latents_to_image(x=x)
|
||||||
|
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion_dpm, min_psnr=35, min_ssim=0.98)
|
||||||
|
|
||||||
|
|
||||||
@no_grad()
|
@no_grad()
|
||||||
def test_t2i_adapter_depth(
|
def test_t2i_adapter_depth(
|
||||||
sd15_std: StableDiffusion_1,
|
sd15_std: StableDiffusion_1,
|
||||||
|
@ -2427,3 +2504,13 @@ def test_style_aligned(
|
||||||
|
|
||||||
# compare against reference image
|
# compare against reference image
|
||||||
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
|
ensure_similar_images(merged_image, expected_style_aligned, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def test_multi_upscaler(
|
||||||
|
multi_upscaler: MultiUpscaler,
|
||||||
|
clarity_example: Image.Image,
|
||||||
|
expected_multi_upscaler: Image.Image,
|
||||||
|
) -> None:
|
||||||
|
predicted_image = multi_upscaler.upscale(clarity_example)
|
||||||
|
ensure_similar_images(predicted_image, expected_multi_upscaler, min_psnr=35, min_ssim=0.99)
|
||||||
|
|
|
@ -58,6 +58,8 @@ Special cases:
|
||||||
- `expected_controllora_disabled.png`
|
- `expected_controllora_disabled.png`
|
||||||
- `expected_style_aligned.png`
|
- `expected_style_aligned.png`
|
||||||
- `expected_controlnet_canny_scale_decay.png`
|
- `expected_controlnet_canny_scale_decay.png`
|
||||||
|
- `expected_multi_diffusion_dpm.png`
|
||||||
|
- `expected_multi_upscaler.png`
|
||||||
|
|
||||||
## Other images
|
## Other images
|
||||||
|
|
||||||
|
@ -92,6 +94,8 @@ Special cases:
|
||||||
- `low_res_dog.png` and `expected_controlnet_tile.png` are taken from Diffusers [documentation](https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/tree/main/images), respectively named
|
- `low_res_dog.png` and `expected_controlnet_tile.png` are taken from Diffusers [documentation](https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile/tree/main/images), respectively named
|
||||||
`original.png` and `output.png`.
|
`original.png` and `output.png`.
|
||||||
|
|
||||||
|
- `clarity_input_example.png` is taken from the [Replicate demo](https://replicate.com/philz1337x/clarity-upscaler/examples) of the Clarity upscaler.
|
||||||
|
|
||||||
## VAE without randomness
|
## VAE without randomness
|
||||||
|
|
||||||
```diff
|
```diff
|
||||||
|
|
BIN
tests/e2e/test_diffusion_ref/clarity_input_example.png
Normal file
BIN
tests/e2e/test_diffusion_ref/clarity_input_example.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 949 KiB |
BIN
tests/e2e/test_diffusion_ref/expected_multi_diffusion_dpm.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_multi_diffusion_dpm.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.9 MiB |
BIN
tests/e2e/test_diffusion_ref/expected_multi_upscaler.png
Normal file
BIN
tests/e2e/test_diffusion_ref/expected_multi_upscaler.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.9 MiB |
34
tests/foundationals/latent_diffusion/test_multi_diffusion.py
Normal file
34
tests/foundationals/latent_diffusion/test_multi_diffusion.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from refiners.foundationals.latent_diffusion.multi_diffusion import MultiDiffusion, Size
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_latent_tiles() -> None:
|
||||||
|
size = Size(height=128, width=128)
|
||||||
|
tile_size = Size(height=32, width=32)
|
||||||
|
tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size)
|
||||||
|
assert len(tiles) == 25
|
||||||
|
|
||||||
|
tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=0)
|
||||||
|
assert len(tiles) == 16
|
||||||
|
|
||||||
|
size = Size(height=100, width=200)
|
||||||
|
tile_size = Size(height=32, width=32)
|
||||||
|
tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=2)
|
||||||
|
assert len(tiles) == 28
|
||||||
|
|
||||||
|
|
||||||
|
def test_generate_latent_tiles_small_size() -> None:
|
||||||
|
# Test when the size is smaller than the tile size
|
||||||
|
size = Size(height=32, width=32)
|
||||||
|
tile_size = Size(height=64, width=64)
|
||||||
|
tiles = MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size)
|
||||||
|
assert len(tiles) == 1
|
||||||
|
assert Size(tiles[0].bottom - tiles[0].top, tiles[0].right - tiles[0].left) == size
|
||||||
|
|
||||||
|
|
||||||
|
def test_overlap_larger_tile_size() -> None:
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
size = Size(height=128, width=128)
|
||||||
|
tile_size = Size(height=32, width=32)
|
||||||
|
MultiDiffusion.generate_latent_tiles(size=size, tile_size=tile_size, min_overlap=32)
|
Loading…
Reference in a new issue