mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add test_diffusion_std_random_init_bfloat16 e2e test
This commit is contained in:
parent
12622ad114
commit
4360aa046f
|
@ -92,6 +92,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
|
|||
return _img_open(ref_path / "expected_std_random_init.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_random_init_bfloat16(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_std_random_init_bfloat16.png").convert("RGB")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_image_std_sde_random_init(ref_path: Path) -> Image.Image:
|
||||
return _img_open(ref_path / "expected_std_sde_random_init.png").convert("RGB")
|
||||
|
@ -637,6 +642,26 @@ def sd15_std_float16(
|
|||
return sd15
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sd15_std_bfloat16(
|
||||
text_encoder_weights: Path,
|
||||
lda_weights: Path,
|
||||
unet_weights_std: Path,
|
||||
test_device: torch.device,
|
||||
) -> StableDiffusion_1:
|
||||
if test_device.type == "cpu":
|
||||
warn("not running on CPU, skipping")
|
||||
pytest.skip()
|
||||
|
||||
sd15 = StableDiffusion_1(device=test_device, dtype=torch.bfloat16)
|
||||
|
||||
sd15.clip_text_encoder.load_from_safetensors(text_encoder_weights)
|
||||
sd15.lda.load_from_safetensors(lda_weights)
|
||||
sd15.unet.load_from_safetensors(unet_weights_std)
|
||||
|
||||
return sd15
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sd15_inpainting(
|
||||
text_encoder_weights: Path, lda_weights: Path, unet_weights_inpainting: Path, test_device: torch.device
|
||||
|
@ -891,6 +916,34 @@ def test_diffusion_std_random_init(
|
|||
ensure_similar_images(predicted_image, expected_image_std_random_init)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_std_random_init_bfloat16(
|
||||
sd15_std_bfloat16: StableDiffusion_1,
|
||||
expected_image_std_random_init_bfloat16: Image.Image,
|
||||
):
|
||||
sd15 = sd15_std_bfloat16
|
||||
|
||||
prompt = "a cute cat, detailed high-quality professional image"
|
||||
negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
|
||||
clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
|
||||
|
||||
sd15.set_inference_steps(30)
|
||||
|
||||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 64, 64, device=sd15.device, dtype=sd15.dtype)
|
||||
|
||||
for step in sd15.steps:
|
||||
x = sd15(
|
||||
x,
|
||||
step=step,
|
||||
clip_text_embedding=clip_text_embedding,
|
||||
condition_scale=7.5,
|
||||
)
|
||||
predicted_image = sd15.lda.latents_to_image(x)
|
||||
|
||||
ensure_similar_images(predicted_image, expected_image_std_random_init_bfloat16)
|
||||
|
||||
|
||||
@no_grad()
|
||||
def test_diffusion_std_sde_random_init(
|
||||
sd15_std_sde: StableDiffusion_1, expected_image_std_sde_random_init: Image.Image, test_device: torch.device
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 490 KiB |
Loading…
Reference in a new issue