Mirror of https://github.com/finegrain-ai/refiners.git, synced 2024-11-21 21:58:47 +00:00
remove useless torch.no_grad() contexts
This commit is contained in: parent eea340c6c4, commit cf9efb57c8
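Note: the hunks below only delete the per-test "with torch.no_grad():" blocks; nothing in this diff shows what makes gradient tracking unnecessary elsewhere. One common way a pytest suite can make such blocks redundant is a single autouse fixture that wraps every test in no_grad. The sketch below is purely illustrative and is not part of this commit or of the refiners code base:

# conftest.py (hypothetical sketch, not added by this commit)
import pytest
import torch

@pytest.fixture(autouse=True)
def no_grad_everywhere():
    # Disable gradient tracking for the duration of each test, so individual
    # tests no longer need their own torch.no_grad() contexts.
    with torch.no_grad():
        yield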
@@ -401,24 +401,21 @@ def test_diffusion_std_random_init(

     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_std_random_init)

@@ -432,10 +429,7 @@ def test_diffusion_std_random_init_float16(

     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     assert clip_text_embedding.dtype == torch.float16

     sd15.set_num_inference_steps(n_steps)
@@ -443,15 +437,14 @@ def test_diffusion_std_random_init_float16(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)

@@ -468,24 +461,21 @@ def test_diffusion_std_init_image(

     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

     manual_seed(2)
     x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step)

-    with torch.no_grad():
-        for step in sd15.steps[first_step:]:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps[first_step:]:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_std_init_image)

@@ -503,9 +493,7 @@ def test_diffusion_inpainting(

     prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)
     sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask)
@@ -513,15 +501,14 @@ def test_diffusion_inpainting(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     # PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves.
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
@@ -540,10 +527,7 @@ def test_diffusion_inpainting_float16(

     prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     assert clip_text_embedding.dtype == torch.float16

     sd15.set_num_inference_steps(n_steps)
@@ -552,15 +536,14 @@ def test_diffusion_inpainting_float16(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     # PSNR and SSIM values are large because float16 is even worse than float32.
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
@@ -583,9 +566,7 @@ def test_diffusion_controlnet(

     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -598,16 +579,15 @@ def test_diffusion_controlnet(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            controlnet.set_controlnet_condition(cn_condition)
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        controlnet.set_controlnet_condition(cn_condition)
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

@@ -630,9 +610,7 @@ def test_diffusion_controlnet_structural_copy(

     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -645,16 +623,15 @@ def test_diffusion_controlnet_structural_copy(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            controlnet.set_controlnet_condition(cn_condition)
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        controlnet.set_controlnet_condition(cn_condition)
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

@@ -676,9 +653,7 @@ def test_diffusion_controlnet_float16(

     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -691,16 +666,15 @@ def test_diffusion_controlnet_float16(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

-    with torch.no_grad():
-        for step in sd15.steps:
-            controlnet.set_controlnet_condition(cn_condition)
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        controlnet.set_controlnet_condition(cn_condition)
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

@@ -729,9 +703,7 @@ def test_diffusion_controlnet_stack(

     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -748,17 +720,16 @@ def test_diffusion_controlnet_stack(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            depth_controlnet.set_controlnet_condition(depth_cn_condition)
-            canny_controlnet.set_controlnet_condition(canny_cn_condition)
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        depth_controlnet.set_controlnet_condition(depth_cn_condition)
+        canny_controlnet.set_controlnet_condition(canny_cn_condition)
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)

@@ -779,9 +750,7 @@ def test_diffusion_lora(
         pytest.skip(allow_module_level=True)

     prompt = "a cute cat"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -790,15 +759,14 @@ def test_diffusion_lora(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

@@ -819,9 +787,7 @@ def test_diffusion_lora_float16(
         pytest.skip(allow_module_level=True)

     prompt = "a cute cat"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -830,15 +796,14 @@ def test_diffusion_lora_float16(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)

@@ -859,9 +824,7 @@ def test_diffusion_lora_twice(
         pytest.skip(allow_module_level=True)

     prompt = "a cute cat"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sd15.set_num_inference_steps(n_steps)

@@ -872,15 +835,14 @@ def test_diffusion_lora_twice(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)

@@ -893,10 +855,9 @@ def test_diffusion_refonly(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim
-    prompt = "Chicken"

-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
+    prompt = "Chicken"
+    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sai = ReferenceOnlyControlAdapter(sd15.unet).inject()

@@ -906,19 +867,18 @@ def test_diffusion_refonly(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            noise = torch.randn(2, 4, 64, 64, device=test_device)
-            noised_guide = sd15.scheduler.add_noise(guide, noise, step)
-            sai.set_controlnet_condition(noised_guide)
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-            torch.randn(2, 4, 64, 64, device=test_device)  # for SD Web UI reproductibility only
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        noise = torch.randn(2, 4, 64, 64, device=test_device)
+        noised_guide = sd15.scheduler.add_noise(guide, noise, step)
+        sai.set_controlnet_condition(noised_guide)
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+        torch.randn(2, 4, 64, 64, device=test_device)  # for SD Web UI reproductibility only
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99)

@@ -934,10 +894,9 @@ def test_diffusion_inpainting_refonly(
 ):
     sd15 = sd15_inpainting
     n_steps = 30
-    prompt = ""  # unconditional

-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
+    prompt = ""  # unconditional
+    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sai = ReferenceOnlyControlAdapter(sd15.unet).inject()

@@ -950,22 +909,21 @@ def test_diffusion_inpainting_refonly(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            noise = torch.randn_like(guide)
-            noised_guide = sd15.scheduler.add_noise(guide, noise, step)
-            # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
-            # inpaint variation models")
-            noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)
+    for step in sd15.steps:
+        noise = torch.randn_like(guide)
+        noised_guide = sd15.scheduler.add_noise(guide, noise, step)
+        # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support
+        # inpaint variation models")
+        noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1)

-            sai.set_controlnet_condition(noised_guide)
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+        sai.set_controlnet_condition(noised_guide)
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)

@@ -986,24 +944,21 @@ def test_diffusion_textual_inversion_random_init(
     n_steps = 30

     prompt = "a cute cat on a <gta5-artwork>"
-
-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
+    clip_text_embedding = sd15.compute_clip_text_embedding(prompt)

     sd15.set_num_inference_steps(n_steps)

     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)

@@ -1033,34 +988,32 @@ def test_diffusion_ip_adapter(
     ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
     ip_adapter.inject()

-    with torch.no_grad():
-        clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
-        clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
+    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
+    clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))

-        negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
-        negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)
+    negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
+    negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

-        clip_text_embedding = torch.cat(
-            (
-                torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
-                torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
-            )
-        )
+    clip_text_embedding = torch.cat(
+        (
+            torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
+            torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
+        )
+    )

     sd15.set_num_inference_steps(n_steps)

     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)

-    with torch.no_grad():
-        for step in sd15.steps:
-            x = sd15(
-                x,
-                step=step,
-                clip_text_embedding=clip_text_embedding,
-                condition_scale=7.5,
-            )
-        predicted_image = sd15.lda.decode_latents(x)
+    for step in sd15.steps:
+        x = sd15(
+            x,
+            step=step,
+            clip_text_embedding=clip_text_embedding,
+            condition_scale=7.5,
+        )
+    predicted_image = sd15.lda.decode_latents(x)

     ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)

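For reference, after this change every text-to-image test in the file shares the same denoising pattern. The sketch below assembles that pattern from the hunks above into one function; it is illustrative only: fixture setup (the loaded sd15 pipeline and its weights) is elided, and torch.manual_seed stands in for the suite's manual_seed helper.

# Illustrative consolidation of the post-change test body; sd15 is assumed to be
# a loaded refiners Stable Diffusion 1.5 pipeline provided elsewhere (not shown).
import torch

def denoise(sd15, prompt: str, negative_prompt: str, device: torch.device, n_steps: int = 30):
    # Text conditioning (positive and negative prompts in one embedding).
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_num_inference_steps(n_steps)

    torch.manual_seed(2)  # the tests use refiners' manual_seed helper instead
    x = torch.randn(1, 4, 64, 64, device=device)  # initial latent noise

    # Iterative denoising, exactly as in the tests above.
    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )
    # Decode the final latents; the tests compare this image against a reference
    # with ensure_similar_images(...).
    return sd15.lda.decode_latents(x)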