From cf9efb57c8bfe57b869b4763194cb82cc277a604 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Tue, 12 Sep 2023 17:55:39 +0200 Subject: [PATCH] remove useless torch.no_grad() contexts --- tests/e2e/test_diffusion.py | 385 ++++++++++++++++-------------------- 1 file changed, 169 insertions(+), 216 deletions(-) diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 48f38bf..1dfecae 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -401,24 +401,21 @@ def test_diffusion_std_random_init( prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_num_inference_steps(n_steps) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_std_random_init) @@ -432,10 +429,7 @@ def test_diffusion_std_random_init_float16( prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) - + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) assert clip_text_embedding.dtype == torch.float16 sd15.set_num_inference_steps(n_steps) @@ -443,15 +437,14 @@ def test_diffusion_std_random_init_float16( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98) @@ -468,24 +461,21 @@ def test_diffusion_std_init_image( prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_num_inference_steps(n_steps) manual_seed(2) x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step) - with torch.no_grad(): - for step in sd15.steps[first_step:]: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps[first_step:]: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_std_init_image) @@ -503,9 +493,7 @@ def test_diffusion_inpainting( prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_num_inference_steps(n_steps) sd15.set_inpainting_conditions(kitchen_dog, kitchen_dog_mask) @@ -513,15 +501,14 @@ def test_diffusion_inpainting( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) # PSNR and SSIM values are large because with float32 we get large differences even v.s. ourselves. ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95) @@ -540,10 +527,7 @@ def test_diffusion_inpainting_float16( prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) - + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) assert clip_text_embedding.dtype == torch.float16 sd15.set_num_inference_steps(n_steps) @@ -552,15 +536,14 @@ def test_diffusion_inpainting_float16( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) # PSNR and SSIM values are large because float16 is even worse than float32. ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92) @@ -583,9 +566,7 @@ def test_diffusion_controlnet( prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_num_inference_steps(n_steps) @@ -598,16 +579,15 @@ def test_diffusion_controlnet( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - controlnet.set_controlnet_condition(cn_condition) - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + controlnet.set_controlnet_condition(cn_condition) + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @@ -630,9 +610,7 @@ def test_diffusion_controlnet_structural_copy( prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_num_inference_steps(n_steps) @@ -645,16 +623,15 @@ def test_diffusion_controlnet_structural_copy( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - controlnet.set_controlnet_condition(cn_condition) - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + controlnet.set_controlnet_condition(cn_condition) + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @@ -676,9 +653,7 @@ def test_diffusion_controlnet_float16( prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_num_inference_steps(n_steps) @@ -691,16 +666,15 @@ def test_diffusion_controlnet_float16( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) - with torch.no_grad(): - for step in sd15.steps: - controlnet.set_controlnet_condition(cn_condition) - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + controlnet.set_controlnet_condition(cn_condition) + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @@ -729,9 +703,7 @@ def test_diffusion_controlnet_stack( prompt = "a cute cat, detailed high-quality professional image" negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) sd15.set_num_inference_steps(n_steps) @@ -748,17 +720,16 @@ def test_diffusion_controlnet_stack( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - depth_controlnet.set_controlnet_condition(depth_cn_condition) - canny_controlnet.set_controlnet_condition(canny_cn_condition) - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + depth_controlnet.set_controlnet_condition(depth_cn_condition) + canny_controlnet.set_controlnet_condition(canny_cn_condition) + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98) @@ -779,9 +750,7 @@ def test_diffusion_lora( pytest.skip(allow_module_level=True) prompt = "a cute cat" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sd15.set_num_inference_steps(n_steps) @@ -790,15 +759,14 @@ def test_diffusion_lora( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @@ -819,9 +787,7 @@ def test_diffusion_lora_float16( pytest.skip(allow_module_level=True) prompt = "a cute cat" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sd15.set_num_inference_steps(n_steps) @@ -830,15 +796,14 @@ def test_diffusion_lora_float16( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98) @@ -859,9 +824,7 @@ def test_diffusion_lora_twice( pytest.skip(allow_module_level=True) prompt = "a cute cat" - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sd15.set_num_inference_steps(n_steps) @@ -872,15 +835,14 @@ def test_diffusion_lora_twice( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) @@ -893,10 +855,9 @@ def test_diffusion_refonly( test_device: torch.device, ): sd15 = sd15_ddim - prompt = "Chicken" - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + prompt = "Chicken" + clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sai = ReferenceOnlyControlAdapter(sd15.unet).inject() @@ -906,19 +867,18 @@ def test_diffusion_refonly( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - noise = torch.randn(2, 4, 64, 64, device=test_device) - noised_guide = sd15.scheduler.add_noise(guide, noise, step) - sai.set_controlnet_condition(noised_guide) - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - torch.randn(2, 4, 64, 64, device=test_device) # for SD Web UI reproductibility only - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + noise = torch.randn(2, 4, 64, 64, device=test_device) + noised_guide = sd15.scheduler.add_noise(guide, noise, step) + sai.set_controlnet_condition(noised_guide) + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + torch.randn(2, 4, 64, 64, device=test_device) # for SD Web UI reproductibility only + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99) @@ -934,10 +894,9 @@ def test_diffusion_inpainting_refonly( ): sd15 = sd15_inpainting n_steps = 30 - prompt = "" # unconditional - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + prompt = "" # unconditional + clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sai = ReferenceOnlyControlAdapter(sd15.unet).inject() @@ -950,22 +909,21 @@ def test_diffusion_inpainting_refonly( manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - noise = torch.randn_like(guide) - noised_guide = sd15.scheduler.add_noise(guide, noise, step) - # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support - # inpaint variation models") - noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1) + for step in sd15.steps: + noise = torch.randn_like(guide) + noised_guide = sd15.scheduler.add_noise(guide, noise, step) + # See https://github.com/Mikubill/sd-webui-controlnet/pull/1275 ("1.1.170 reference-only begin to support + # inpaint variation models") + noised_guide = torch.cat([noised_guide, torch.zeros_like(noised_guide)[:, 0:1, :, :], guide], dim=1) - sai.set_controlnet_condition(noised_guide) - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + sai.set_controlnet_condition(noised_guide) + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99) @@ -986,24 +944,21 @@ def test_diffusion_textual_inversion_random_init( n_steps = 30 prompt = "a cute cat on a " - - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(prompt) + clip_text_embedding = sd15.compute_clip_text_embedding(prompt) sd15.set_num_inference_steps(n_steps) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98) @@ -1033,34 +988,32 @@ def test_diffusion_ip_adapter( ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() - with torch.no_grad(): - clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) - clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) + clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt) + clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image)) - negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) - negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) + negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2) + negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2) - clip_text_embedding = torch.cat( - ( - torch.cat([negative_text_embedding, negative_image_embedding], dim=1), - torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), - ) + clip_text_embedding = torch.cat( + ( + torch.cat([negative_text_embedding, negative_image_embedding], dim=1), + torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1), ) + ) sd15.set_num_inference_steps(n_steps) manual_seed(2) x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16) - with torch.no_grad(): - for step in sd15.steps: - x = sd15( - x, - step=step, - clip_text_embedding=clip_text_embedding, - condition_scale=7.5, - ) - predicted_image = sd15.lda.decode_latents(x) + for step in sd15.steps: + x = sd15( + x, + step=step, + clip_text_embedding=clip_text_embedding, + condition_scale=7.5, + ) + predicted_image = sd15.lda.decode_latents(x) ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)