diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 4fdf434..a9e92df 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -92,6 +92,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_std_random_init.png").convert("RGB") +@pytest.fixture +def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image: + return _img_open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB") + + @pytest.fixture def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image: return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB") @@ -637,6 +642,26 @@ def sd15_std_float16( return sd15 +@pytest.fixture +def sd15_std_bfloat16( + text_encoder_weights: Path, + lda_weights: Path, + unet_weights_std: Path, + test_device: torch.device, +) -> StableDiffusion_1: + if test_device.type == "cpu": + warn("not running on CPU, skipping") + pytest.skip() + + sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16) + + sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights) + sd15.lda.load_from_safetensors(lda_weights) + sd15.unet.load_from_safetensors(unet_weights_std) + + return sd15 + + @pytest.fixture def sd15_inpainting( text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device @@ -891,6 +916,34 @@ def test_diffusion_std_random_init( ensure_similar_images(predicted_image, expected_image_std_random_init) +@no_grad() +def test_diffusion_std_random_init_bfloat16( + sd15_std_bfloat16: StableDiffusion_1, + expected_image_std_random_init_bfloat16: Image.Image, +): + sd15 = sd15_std_bfloat16 + + prompt = "a cute cat, detailed high-quality professional image" + negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + + sd15.set_inference_steps(30) + + manual_seed(2) + x = torch.randn(1, 4, 64, 64, device=sd15.device, dtype=sd15.dtype) + + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.latents_to_image(x) + + ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16) + + @no_grad() def test_diffusion_std_sde_random_init( sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device diff --git a/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png b/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png new file mode 100644 index 0000000..c46dd89 Binary files /dev/null and b/tests/e2e/test_diffusion_ref/expected_std_random_init_bfloat16.png differ