mirror of https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00

add tests for FreeU

This commit is contained in:
parent 6eeb01137d
commit ab0915d052
@@ -4,6 +4,7 @@ from refiners.foundationals.latent_diffusion.auto_encoder import (
 from refiners.foundationals.clip.text_encoder import (
     CLIPTextEncoderL,
 )
+from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
 from refiners.foundationals.latent_diffusion.schedulers import Scheduler, DPMSolver
 from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
     StableDiffusion_1,
@@ -36,4 +37,5 @@ __all__ = [
     "Scheduler",
     "CLIPTextEncoderL",
     "LatentDiffusionAutoencoder",
+    "SDFreeUAdapter",
 ]
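The two hunks above add SDFreeUAdapter to the imports and `__all__` of what is evidently the `refiners.foundationals.latent_diffusion` package `__init__`, so downstream code can import it from the package namespace directly, which is what the new tests below rely on. A minimal sketch of the resulting import (not part of this commit):

    # After this change, the adapter resolves from the package top level:
    from refiners.foundationals.latent_diffusion import SDFreeUAdapter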
@@ -17,6 +17,7 @@ from refiners.foundationals.latent_diffusion import (
     SD1T2IAdapter,
     SDXLIPAdapter,
     SDXLT2IAdapter,
+    SDFreeUAdapter,
 )
 from refiners.foundationals.latent_diffusion.lora import SD1LoraAdapter
 from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
@@ -227,6 +228,11 @@ def expected_restart(ref_path: Path) -> Image.Image:
     return Image.open(fp=ref_path / "expected_restart.png").convert(mode="RGB")
 
 
+@pytest.fixture
+def expected_freeu(ref_path: Path) -> Image.Image:
+    return Image.open(fp=ref_path / "expected_freeu.png").convert(mode="RGB")
+
+
 @pytest.fixture
 def text_embedding_textual_inversion(test_textual_inversion_path: Path) -> torch.Tensor:
     return torch.load(test_textual_inversion_path / "gta5-artwork" / "learned_embeds.bin")["<gta5-artwork>"]  # type: ignore
@@ -1604,3 +1610,37 @@ def test_restart(
     predicted_image = sd15.lda.decode_latents(x)
 
     ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
+
+
+@torch.no_grad()
+def test_freeu(
+    sd15_std: StableDiffusion_1,
+    expected_freeu: Image.Image,
+):
+    sd15 = sd15_std
+    n_steps = 50
+    first_step = 1
+
+    prompt = "best quality, high quality cute cat"
+    negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+
+    sd15.set_num_inference_steps(n_steps)
+
+    SDFreeUAdapter(
+        sd15.unet, backbone_scales=[1.2, 1.2, 1.2, 1.4, 1.4, 1.4], skip_scales=[0.9, 0.9, 0.9, 0.2, 0.2, 0.2]
+    ).inject()
+
+    manual_seed(9752)
+    x = sd15.init_latents(size=(512, 512), first_step=first_step).to(device=sd15.device, dtype=sd15.dtype)
+
+    for step in sd15.steps[first_step:]:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)
+
+    ensure_similar_images(predicted_image, expected_freeu)
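The backbone/skip scales used above (1.2/1.4 and 0.9/0.2) match values commonly recommended for SD 1.x in the FreeU paper, one pair per patched decoder stage. Conceptually, FreeU re-weights the two inputs of each UNet decoder concatenation; the sketch below illustrates that idea only, is not the refiners implementation (that lives in FreeUResidualConcatenator), and omits the Fourier-domain filtering the full method applies to skip features:

    import torch

    def freeu_concat(backbone: torch.Tensor, skip: torch.Tensor, b: float, s: float) -> torch.Tensor:
        # Amplify backbone (denoising) features and damp skip (detail) features
        # before the usual channel-wise concatenation in the UNet decoder.
        return torch.cat([backbone * b, skip * s], dim=1)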
@@ -45,6 +45,7 @@ Special cases:
 - `expected_image_sdxl_ip_adapter_plus_woman.png`
 - `expected_cutecat_sdxl_ddim_random_init_sag.png`
 - `expected_restart.png`
+- `expected_freeu.png`
 
 ## Other images
 
BIN  tests/e2e/test_diffusion_ref/expected_freeu.png  (new file, 443 KiB)
Binary file not shown.
41  tests/foundationals/latent_diffusion/test_freeu.py  (new file)
@@ -0,0 +1,41 @@
+from typing import Iterator
+
+import pytest
+
+from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
+from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter, FreeUResidualConcatenator
+
+
+@pytest.fixture(scope="module", params=[True, False])
+def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet | SDXLUNet]:
+    xl: bool = request.param
+    unet = SDXLUNet(in_channels=4) if xl else SD1UNet(in_channels=4)
+    yield unet
+
+
+def test_freeu_adapter(unet: SD1UNet | SDXLUNet) -> None:
+    freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])
+
+    assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
+
+    with pytest.raises(AssertionError) as exc:
+        freeu.eject()
+    assert "could not find" in str(exc.value)
+
+    freeu.inject()
+    assert len(list(unet.walk(FreeUResidualConcatenator))) == 2
+
+    freeu.eject()
+    assert len(list(unet.walk(FreeUResidualConcatenator))) == 0
+
+
+def test_freeu_adapter_too_many_scales(unet: SD1UNet | SDXLUNet) -> None:
+    num_blocks = len(unet.UpBlocks)
+
+    with pytest.raises(AssertionError):
+        SDFreeUAdapter(unet, backbone_scales=[1.2] * (num_blocks + 1), skip_scales=[0.9] * (num_blocks + 1))
+
+
+def test_freeu_adapter_inconsistent_scales(unet: SD1UNet | SDXLUNet) -> None:
+    with pytest.raises(AssertionError):
+        SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9, 0.9])
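For readers skimming the diff, a minimal standalone sketch of the adapter lifecycle these tests exercise (nothing here beyond what the tests already show; weights are randomly initialized):

    from refiners.foundationals.latent_diffusion import SD1UNet, SDFreeUAdapter

    unet = SD1UNet(in_channels=4)
    # One backbone/skip scale pair per patched up block (two here).
    freeu = SDFreeUAdapter(unet, backbone_scales=[1.2, 1.2], skip_scales=[0.9, 0.9])

    freeu.inject()  # patches the UNet in place: FreeU scaling is now active
    # ... run denoising with the patched UNet ...
    freeu.eject()   # restores the original, unpatched UNet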