remove useless torch.no_grad() contexts

Pierre Chapuis 2023-09-12 17:55:39 +02:00
parent eea340c6c4
commit cf9efb57c8
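
The `with torch.no_grad():` blocks removed in the diff below are redundant when gradient tracking is already disabled around the test code, for example by an enclosing context or decorator. As a minimal generic-PyTorch sketch of why a nested context adds nothing (illustration only, not code from this commit):

    import torch

    linear = torch.nn.Linear(4, 4)

    with torch.no_grad():  # gradient tracking is off for this whole block
        y_outer = linear(torch.randn(1, 4))
        with torch.no_grad():  # redundant: grad mode is already disabled
            y_inner = linear(torch.randn(1, 4))

    # outputs computed under no_grad do not track gradients either way
    assert not y_outer.requires_grad
    assert not y_inner.requires_grad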


@@ -401,8 +401,6 @@ def test_diffusion_std_random_init(
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -410,7 +408,6 @@ def test_diffusion_std_random_init(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
@@ -432,10 +429,7 @@ def test_diffusion_std_random_init_float16(
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     assert clip_text_embedding.dtype == torch.float16
     sd15.set_num_inference_steps(n_steps)
@@ -443,7 +437,6 @@ def test_diffusion_std_random_init_float16(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
@@ -468,8 +461,6 @@ def test_diffusion_std_init_image(
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -477,7 +468,6 @@ def test_diffusion_std_init_image(
     manual_seed(2)
     x = sd15.init_latents((512, 512), cutecat_init, first_step=first_step)
-    with torch.no_grad():
     for step in sd15.steps[first_step:]:
         x = sd15(
             x,
@@ -503,8 +493,6 @@ def test_diffusion_inpainting(
     prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -513,7 +501,6 @@ def test_diffusion_inpainting(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
@@ -540,10 +527,7 @@ def test_diffusion_inpainting_float16(
     prompt = "a large white cat, detailed high-quality professional image, sitting on a chair, in a kitchen"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     assert clip_text_embedding.dtype == torch.float16
     sd15.set_num_inference_steps(n_steps)
@@ -552,7 +536,6 @@ def test_diffusion_inpainting_float16(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
@@ -583,8 +566,6 @@ def test_diffusion_controlnet(
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -598,7 +579,6 @@ def test_diffusion_controlnet(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         controlnet.set_controlnet_condition(cn_condition)
         x = sd15(
@@ -630,8 +610,6 @@ def test_diffusion_controlnet_structural_copy(
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -645,7 +623,6 @@ def test_diffusion_controlnet_structural_copy(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         controlnet.set_controlnet_condition(cn_condition)
         x = sd15(
@@ -676,8 +653,6 @@ def test_diffusion_controlnet_float16(
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -691,7 +666,6 @@ def test_diffusion_controlnet_float16(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
-    with torch.no_grad():
     for step in sd15.steps:
         controlnet.set_controlnet_condition(cn_condition)
         x = sd15(
@@ -729,8 +703,6 @@ def test_diffusion_controlnet_stack(
     prompt = "a cute cat, detailed high-quality professional image"
     negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -748,7 +720,6 @@ def test_diffusion_controlnet_stack(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         depth_controlnet.set_controlnet_condition(depth_cn_condition)
         canny_controlnet.set_controlnet_condition(canny_cn_condition)
@@ -779,8 +750,6 @@ def test_diffusion_lora(
         pytest.skip(allow_module_level=True)
     prompt = "a cute cat"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -790,7 +759,6 @@ def test_diffusion_lora(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
@@ -819,8 +787,6 @@ def test_diffusion_lora_float16(
         pytest.skip(allow_module_level=True)
     prompt = "a cute cat"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -830,7 +796,6 @@ def test_diffusion_lora_float16(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
@@ -859,8 +824,6 @@ def test_diffusion_lora_twice(
         pytest.skip(allow_module_level=True)
     prompt = "a cute cat"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -872,7 +835,6 @@ def test_diffusion_lora_twice(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
@@ -893,9 +855,8 @@ def test_diffusion_refonly(
     test_device: torch.device,
 ):
     sd15 = sd15_ddim
-    prompt = "Chicken"
-    with torch.no_grad():
+    prompt = "Chicken"
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
     sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
@@ -906,7 +867,6 @@ def test_diffusion_refonly(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         noise = torch.randn(2, 4, 64, 64, device=test_device)
         noised_guide = sd15.scheduler.add_noise(guide, noise, step)
@@ -934,9 +894,8 @@ def test_diffusion_inpainting_refonly(
 ):
     sd15 = sd15_inpainting
     n_steps = 30
-    prompt = ""  # unconditional
-    with torch.no_grad():
+    prompt = ""  # unconditional
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
     sai = ReferenceOnlyControlAdapter(sd15.unet).inject()
@@ -950,7 +909,6 @@ def test_diffusion_inpainting_refonly(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         noise = torch.randn_like(guide)
         noised_guide = sd15.scheduler.add_noise(guide, noise, step)
@@ -986,8 +944,6 @@ def test_diffusion_textual_inversion_random_init(
     n_steps = 30
     prompt = "a cute cat on a <gta5-artwork>"
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(prompt)
     sd15.set_num_inference_steps(n_steps)
@@ -995,7+951,6 @@ def test_diffusion_textual_inversion_random_init(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
@@ -1033,7 +988,6 @@ def test_diffusion_ip_adapter(
     ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
     ip_adapter.inject()
-    with torch.no_grad():
     clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
     clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
@@ -1052,7 +1006,6 @@ def test_diffusion_ip_adapter(
     manual_seed(2)
     x = torch.randn(1, 4, 64, 64, device=test_device, dtype=torch.float16)
-    with torch.no_grad():
     for step in sd15.steps:
         x = sd15(
             x,
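
A common way to make per-test `no_grad` contexts like these unnecessary is to disable gradient tracking once for every test in the module. This is a hypothetical sketch of that pattern using a pytest autouse fixture, not necessarily how this repository achieves it:

    import pytest
    import torch

    @pytest.fixture(autouse=True)
    def disable_grad():
        # Hypothetical: wrap every test in this module in a single
        # torch.no_grad() context, so tests need no local contexts.
        with torch.no_grad():
            yield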