add e2e test for sd15 with karras noise schedule

This commit is contained in:
limiteinductive 2023-12-03 18:07:42 +01:00 committed by Benjamin Trom
parent 6f110ee2b2
commit 90db6ef59d
2 changed files with 49 additions and 0 deletions

View file

@ -25,6 +25,7 @@ from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.schedulers import DDIM from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.clip.concepts import ConceptExtender from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
@ -66,6 +67,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB") return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")
@pytest.fixture
def expected_karras_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")
@pytest.fixture @pytest.fixture
def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image: def expected_image_std_random_init_sag(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB") return Image.open(ref_path / "expected_std_random_init_sag.png").convert("RGB")
@ -416,6 +422,24 @@ def sd15_ddim(
return sd15 return sd15
@pytest.fixture
def sd15_ddim_karras(
text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
if test_device.type == "cpu":
warn("not running on CPU, skipping")
pytest.skip()
ddim_scheduler = DDIM(num_inference_steps=20, noise_schedule=NoiseSchedule.KARRAS)
sd15 = StableDiffusion_1(scheduler=ddim_scheduler, device=test_device)
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
sd15.lda.load_from_safetensors(lda_weights)
sd15.unet.load_from_safetensors(unet_weights_std)
return sd15
@pytest.fixture @pytest.fixture
def sd15_ddim_lda_ft_mse( def sd15_ddim_lda_ft_mse(
text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
@ -507,6 +531,31 @@ def test_diffusion_std_random_init(
ensure_similar_images(predicted_image, expected_image_std_random_init) ensure_similar_images(predicted_image, expected_image_std_random_init)
@torch.no_grad()
def test_diffusion_karras_random_init(
sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
):
sd15 = sd15_ddim_karras
prompt = "a cute cat, detailed high-quality professional image"
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
manual_seed(2)
x = torch.randn(1, 4, 64, 64, device=test_device)
for step in sd15.steps:
x = sd15(
x,
step=step,
clip_text_embedding=clip_text_embedding,
condition_scale=7.5,
)
predicted_image = sd15.lda.decode_latents(x)
ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
@torch.no_grad() @torch.no_grad()
def test_diffusion_std_random_init_float16( def test_diffusion_std_random_init_float16(
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device

Binary file not shown.

After

Width:  |  Height:  |  Size: 512 KiB