diff --git a/README.md b/README.md index 620a37d..4c52c9f 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ from PIL import Image from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL from refiners.foundationals.latent_diffusion import SDXLIPAdapter, SDXLT2IAdapter -from refiners.fluxion.utils import manual_seed, image_to_tensor, load_from_safetensors +from refiners.fluxion.utils import manual_seed, no_grad, image_to_tensor, load_from_safetensors # Load inputs init_image = Image.open("dropy_logo.png") @@ -122,7 +122,7 @@ t2i_adapter.set_scale(0.8) sdxl.set_num_inference_steps(50) sdxl.set_self_attention_guidance(enable=True, scale=0.75) -with torch.no_grad(): +with no_grad(): # Note: default text prompts for IP-Adapter clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality" diff --git a/pyproject.toml b/pyproject.toml index 17d2e93..02f3ad7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ build-backend = "hatchling.build" [tool.rye] managed = true dev-dependencies = [ - "pyright == 1.1.333", + "pyright == 1.1.342", "ruff>=0.0.292", "docformatter>=1.7.5", "pytest>=7.4.2", diff --git a/scripts/conversion/convert_diffusers_controlnet.py b/scripts/conversion/convert_diffusers_controlnet.py index cacdfdd..5185193 100644 --- a/scripts/conversion/convert_diffusers_controlnet.py +++ b/scripts/conversion/convert_diffusers_controlnet.py @@ -7,7 +7,7 @@ from diffusers import ControlNetModel # type: ignore from torch import nn from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import save_to_safetensors +from refiners.fluxion.utils import no_grad, save_to_safetensors from refiners.foundationals.latent_diffusion import ( DPMSolver, SD1ControlnetAdapter, @@ -20,7 +20,7 @@ class Args(argparse.Namespace): output_path: str | None -@torch.no_grad() +@no_grad() def convert(args: Args) -> dict[str, torch.Tensor]: # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate` controlnet_src: nn.Module = ControlNetModel.from_pretrained( # type: ignore diff --git a/scripts/conversion/convert_diffusers_lora.py b/scripts/conversion/convert_diffusers_lora.py index 9abffd8..1c37d8d 100644 --- a/scripts/conversion/convert_diffusers_lora.py +++ b/scripts/conversion/convert_diffusers_lora.py @@ -11,7 +11,7 @@ from torch.nn.init import zeros_ import refiners.fluxion.layers as fl from refiners.fluxion.adapters.lora import Lora, LoraAdapter from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import save_to_safetensors +from refiners.fluxion.utils import no_grad, save_to_safetensors from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets @@ -37,7 +37,7 @@ class Args(argparse.Namespace): verbose: bool -@torch.no_grad() +@no_grad() def process(args: Args) -> None: diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate` diff --git a/src/refiners/fluxion/layers/sampling.py b/src/refiners/fluxion/layers/sampling.py index d6368e3..69d1412 100644 --- a/src/refiners/fluxion/layers/sampling.py +++ b/src/refiners/fluxion/layers/sampling.py @@ -1,3 +1,5 @@ +from typing import Callable + from torch import Size, Tensor, device as Device, dtype as DType from torch.nn.functional import pad @@ -40,7 +42,8 @@ class Downsample(Chain): ), ) if padding == 0: - self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1)))) + zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1)) + self.insert(0, Lambda(zero_pad)) if register_shape: self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape)) diff --git a/src/refiners/fluxion/model_converter.py b/src/refiners/fluxion/model_converter.py index ef14d02..8e47ebb 100644 --- a/src/refiners/fluxion/model_converter.py +++ b/src/refiners/fluxion/model_converter.py @@ -7,7 +7,7 @@ import torch from torch import Tensor, nn from torch.utils.hooks import RemovableHandle -from refiners.fluxion.utils import norm, save_to_safetensors +from refiners.fluxion.utils import no_grad, norm, save_to_safetensors TORCH_BASIC_LAYERS: list[type[nn.Module]] = [ nn.Conv1d, @@ -512,7 +512,7 @@ class ModelConverter: return True - @torch.no_grad() + @no_grad() def _trace_module_execution_order( self, module: nn.Module, @@ -603,7 +603,7 @@ class ModelConverter: return converted_state_dict - @torch.no_grad() + @no_grad() def _collect_layers_outputs( self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str] ) -> list[tuple[str, Tensor]]: diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 7c4f5e0..deb0d46 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Iterable, Literal, TypeVar +from typing import Any, Iterable, Literal, TypeVar import torch from jaxtyping import Float @@ -7,7 +7,14 @@ from numpy import array, float32 from PIL import Image from safetensors import safe_open as _safe_open # type: ignore from safetensors.torch import save_file as _save_file # type: ignore -from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm # type: ignore +from torch import ( + Tensor, + device as Device, + dtype as DType, + manual_seed as _manual_seed, # type: ignore + no_grad as _no_grad, # type: ignore + norm as _norm, # type: ignore +) from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore T = TypeVar("T") @@ -22,6 +29,11 @@ def manual_seed(seed: int) -> None: _manual_seed(seed) +class no_grad(_no_grad): + def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore + return object.__new__(cls) + + def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor: return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore diff --git a/src/refiners/foundationals/clip/image_encoder.py b/src/refiners/foundationals/clip/image_encoder.py index ed6db3d..270d1be 100644 --- a/src/refiners/foundationals/clip/image_encoder.py +++ b/src/refiners/foundationals/clip/image_encoder.py @@ -1,4 +1,6 @@ -from torch import device as Device, dtype as DType +from typing import Callable + +from torch import Tensor, device as Device, dtype as DType import refiners.fluxion.layers as fl from refiners.foundationals.clip.common import FeedForward, PositionalEncoder @@ -126,6 +128,7 @@ class CLIPImageEncoder(fl.Chain): self.num_layers = num_layers self.num_attention_heads = num_attention_heads self.feedforward_dim = feedforward_dim + cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :] super().__init__( ViTEmbeddings( image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype @@ -142,7 +145,7 @@ class CLIPImageEncoder(fl.Chain): ) for _ in range(num_layers) ), - fl.Lambda(func=lambda x: x[:, 0, :]), + fl.Lambda(func=cls_token_pooling), fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype), fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype), ) diff --git a/src/refiners/foundationals/latent_diffusion/freeu.py b/src/refiners/foundationals/latent_diffusion/freeu.py index 61726bd..3e2580f 100644 --- a/src/refiners/foundationals/latent_diffusion/freeu.py +++ b/src/refiners/foundationals/latent_diffusion/freeu.py @@ -1,5 +1,5 @@ import math -from typing import Any, Generic, TypeVar +from typing import Any, Callable, Generic, TypeVar import torch from torch import Tensor @@ -54,9 +54,10 @@ class FreeUBackboneFeatures(fl.Module): class FreeUSkipFeatures(fl.Chain): def __init__(self, n: int, skip_scale: float) -> None: + apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale) super().__init__( fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]), - fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)), + fl.Lambda(apply_filter), ) diff --git a/src/refiners/foundationals/latent_diffusion/lora.py b/src/refiners/foundationals/latent_diffusion/lora.py index bc041c8..d8820ab 100644 --- a/src/refiners/foundationals/latent_diffusion/lora.py +++ b/src/refiners/foundationals/latent_diffusion/lora.py @@ -122,7 +122,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]): assert metadata is not None, "Invalid safetensors checkpoint: missing metadata" tensors = load_from_safetensors(checkpoint_path, device=target.device) - sub_targets = {} + sub_targets: dict[str, list[LoraTarget]] = {} for model_name in MODELS: if not (v := metadata.get(f"{model_name}_targets", "")): continue diff --git a/src/refiners/foundationals/latent_diffusion/reference_only_control.py b/src/refiners/foundationals/latent_diffusion/reference_only_control.py index bf17bc7..1f0e049 100644 --- a/src/refiners/foundationals/latent_diffusion/reference_only_control.py +++ b/src/refiners/foundationals/latent_diffusion/reference_only_control.py @@ -1,3 +1,5 @@ +from typing import Callable + from torch import Tensor from refiners.fluxion.adapters.adapter import Adapter @@ -45,8 +47,9 @@ class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]): ) with self.setup_adapter(target): + slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1] super().__init__( - Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)), + Parallel(sa_guided, Chain(Lambda(slice_tensor), target)), Lambda(self.compute_averaged_unconditioned_x), ) diff --git a/src/refiners/foundationals/segment_anything/model.py b/src/refiners/foundationals/segment_anything/model.py index f8abfb7..a7da415 100644 --- a/src/refiners/foundationals/segment_anything/model.py +++ b/src/refiners/foundationals/segment_anything/model.py @@ -7,7 +7,7 @@ from PIL import Image from torch import Tensor, device as Device, dtype as DType import refiners.fluxion.layers as fl -from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad +from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder @@ -39,7 +39,7 @@ class SegmentAnything(fl.Module): self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype) self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype) - @torch.no_grad() + @no_grad() def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding: original_size = (image.height, image.width) target_size = self.compute_target_size(original_size) @@ -48,7 +48,7 @@ class SegmentAnything(fl.Module): original_image_size=original_size, ) - @torch.no_grad() + @no_grad() def predict( self, input: Image.Image | ImageEmbedding, diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index dff85fc..f4f8ccf 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -250,7 +250,7 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]): self.timestep_bins[bin_index].append(loss_value) def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None: - log_data = {} + log_data: dict[str, WandbLoggable] = {} for bin_index, losses in self.timestep_bins.items(): if losses: avg_loss = sum(losses) / len(losses) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 87276d8..730bb8a 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, Iterable, TypeVar, cast import numpy as np from loguru import logger -from torch import Tensor, cuda, device as Device, get_rng_state, no_grad, set_rng_state, stack +from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack from torch.autograd import backward from torch.nn import Parameter from torch.optim import Optimizer @@ -26,7 +26,7 @@ from torch.optim.lr_scheduler import ( from torch.utils.data import DataLoader, Dataset from refiners.fluxion import layers as fl -from refiners.fluxion.utils import manual_seed +from refiners.fluxion.utils import manual_seed, no_grad from refiners.training_utils.callback import ( Callback, ClockCallback, diff --git a/tests/e2e/test_diffusion.py b/tests/e2e/test_diffusion.py index 201c1b0..9cf70e8 100644 --- a/tests/e2e/test_diffusion.py +++ b/tests/e2e/test_diffusion.py @@ -6,7 +6,7 @@ import pytest import torch from PIL import Image -from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed +from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad from refiners.foundationals.clip.concepts import ConceptExtender from refiners.foundationals.latent_diffusion import ( SD1ControlnetAdapter, @@ -501,7 +501,7 @@ def sdxl_ddim( return sdxl -@torch.no_grad() +@no_grad() def test_diffusion_std_random_init( sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device ): @@ -529,7 +529,7 @@ def test_diffusion_std_random_init( ensure_similar_images(predicted_image, expected_image_std_random_init) -@torch.no_grad() +@no_grad() def test_diffusion_karras_random_init( sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device ): @@ -554,7 +554,7 @@ def test_diffusion_karras_random_init( ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_std_random_init_float16( sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device ): @@ -583,7 +583,7 @@ def test_diffusion_std_random_init_float16( ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_std_random_init_sag( sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device ): @@ -612,7 +612,7 @@ def test_diffusion_std_random_init_sag( ensure_similar_images(predicted_image, expected_image_std_random_init_sag) -@torch.no_grad() +@no_grad() def test_diffusion_std_init_image( sd15_std: StableDiffusion_1, cutecat_init: Image.Image, @@ -643,7 +643,7 @@ def test_diffusion_std_init_image( ensure_similar_images(predicted_image, expected_image_std_init_image) -@torch.no_grad() +@no_grad() def test_rectangular_init_latents( sd15_std: StableDiffusion_1, cutecat_init: Image.Image, @@ -658,7 +658,7 @@ def test_rectangular_init_latents( assert sd15.lda.decode_latents(x).size == (width, height) -@torch.no_grad() +@no_grad() def test_diffusion_inpainting( sd15_inpainting: StableDiffusion_1_Inpainting, kitchen_dog: Image.Image, @@ -692,7 +692,7 @@ def test_diffusion_inpainting( ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95) -@torch.no_grad() +@no_grad() def test_diffusion_inpainting_float16( sd15_inpainting_float16: StableDiffusion_1_Inpainting, kitchen_dog: Image.Image, @@ -727,7 +727,7 @@ def test_diffusion_inpainting_float16( ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92) -@torch.no_grad() +@no_grad() def test_diffusion_controlnet( sd15_std: StableDiffusion_1, controlnet_data: tuple[str, Image.Image, Image.Image, Path], @@ -770,7 +770,7 @@ def test_diffusion_controlnet( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_controlnet_structural_copy( sd15_std: StableDiffusion_1, controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path], @@ -814,7 +814,7 @@ def test_diffusion_controlnet_structural_copy( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_controlnet_float16( sd15_std_float16: StableDiffusion_1, controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path], @@ -857,7 +857,7 @@ def test_diffusion_controlnet_float16( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_controlnet_stack( sd15_std: StableDiffusion_1, controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path], @@ -912,7 +912,7 @@ def test_diffusion_controlnet_stack( ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_lora( sd15_std: StableDiffusion_1, lora_data_pokemon: tuple[Image.Image, Path], @@ -949,7 +949,7 @@ def test_diffusion_lora( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_lora_float16( sd15_std_float16: StableDiffusion_1, lora_data_pokemon: tuple[Image.Image, Path], @@ -986,7 +986,7 @@ def test_diffusion_lora_float16( ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_lora_twice( sd15_std: StableDiffusion_1, lora_data_pokemon: tuple[Image.Image, Path], @@ -1025,7 +1025,7 @@ def test_diffusion_lora_twice( ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_refonly( sd15_ddim: StableDiffusion_1, condition_image_refonly: Image.Image, @@ -1061,7 +1061,7 @@ def test_diffusion_refonly( ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99) -@torch.no_grad() +@no_grad() def test_diffusion_inpainting_refonly( sd15_inpainting: StableDiffusion_1_Inpainting, scene_image_inpainting_refonly: Image.Image, @@ -1106,7 +1106,7 @@ def test_diffusion_inpainting_refonly( ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99) -@torch.no_grad() +@no_grad() def test_diffusion_textual_inversion_random_init( sd15_std: StableDiffusion_1, expected_image_textual_inversion_random_init: Image.Image, @@ -1141,7 +1141,7 @@ def test_diffusion_textual_inversion_random_init( ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_ip_adapter( sd15_ddim_lda_ft_mse: StableDiffusion_1, ip_adapter_weights: Path, @@ -1196,7 +1196,7 @@ def test_diffusion_ip_adapter( ensure_similar_images(predicted_image, expected_image_ip_adapter_woman) -@torch.no_grad() +@no_grad() def test_diffusion_sdxl_ip_adapter( sdxl_ddim: StableDiffusion_XL, sdxl_ip_adapter_weights: Path, @@ -1215,7 +1215,7 @@ def test_diffusion_sdxl_ip_adapter( ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights) ip_adapter.inject() - with torch.no_grad(): + with no_grad(): clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding( text=prompt, negative_text=negative_prompt ) @@ -1236,7 +1236,7 @@ def test_diffusion_sdxl_ip_adapter( manual_seed(2) x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16) - with torch.no_grad(): + with no_grad(): for step in sdxl.steps: x = sdxl( x, @@ -1254,7 +1254,7 @@ def test_diffusion_sdxl_ip_adapter( ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman) -@torch.no_grad() +@no_grad() def test_diffusion_ip_adapter_controlnet( sd15_ddim: StableDiffusion_1, ip_adapter_weights: Path, @@ -1320,7 +1320,7 @@ def test_diffusion_ip_adapter_controlnet( ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet) -@torch.no_grad() +@no_grad() def test_diffusion_ip_adapter_plus( sd15_ddim_lda_ft_mse: StableDiffusion_1, ip_adapter_plus_weights: Path, @@ -1371,7 +1371,7 @@ def test_diffusion_ip_adapter_plus( ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_diffusion_sdxl_ip_adapter_plus( sdxl_ddim: StableDiffusion_XL, sdxl_ip_adapter_plus_weights: Path, @@ -1427,7 +1427,7 @@ def test_diffusion_sdxl_ip_adapter_plus( ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman) -@torch.no_grad() +@no_grad() def test_sdxl_random_init( sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device ) -> None: @@ -1462,7 +1462,7 @@ def test_sdxl_random_init( ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_sdxl_random_init_sag( sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device ) -> None: @@ -1498,7 +1498,7 @@ def test_sdxl_random_init_sag( ensure_similar_images(img_1=predicted_image, img_2=expected_image) -@torch.no_grad() +@no_grad() def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None: manual_seed(seed=2) sd = sd15_ddim @@ -1529,7 +1529,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_t2i_adapter_depth( sd15_std: StableDiffusion_1, t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path], @@ -1570,7 +1570,7 @@ def test_t2i_adapter_depth( ensure_similar_images(predicted_image, expected_image) -@torch.no_grad() +@no_grad() def test_t2i_adapter_xl_canny( sdxl_ddim: StableDiffusion_XL, t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path], @@ -1619,7 +1619,7 @@ def test_t2i_adapter_xl_canny( ensure_similar_images(predicted_image, expected_image) -@torch.no_grad() +@no_grad() def test_restart( sd15_ddim: StableDiffusion_1, expected_restart: Image.Image, @@ -1659,7 +1659,7 @@ def test_restart( ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98) -@torch.no_grad() +@no_grad() def test_freeu( sd15_std: StableDiffusion_1, expected_freeu: Image.Image, diff --git a/tests/e2e/test_preprocessors.py b/tests/e2e/test_preprocessors.py index 69131b2..4492638 100644 --- a/tests/e2e/test_preprocessors.py +++ b/tests/e2e/test_preprocessors.py @@ -5,7 +5,7 @@ import pytest import torch from PIL import Image -from refiners.fluxion.utils import image_to_tensor, tensor_to_image +from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings from tests.utils import ensure_similar_images @@ -41,7 +41,7 @@ def informative_drawings_model(informative_drawings_weights: Path, test_device: return model -@torch.no_grad() +@no_grad() def test_preprocessor_informative_drawing( informative_drawings_model: InformativeDrawings, cutecat_init: Image.Image, diff --git a/tests/fluxion/layers/test_converter.py b/tests/fluxion/layers/test_converter.py index 8a33188..356cdb1 100644 --- a/tests/fluxion/layers/test_converter.py +++ b/tests/fluxion/layers/test_converter.py @@ -1,3 +1,4 @@ +from typing import Any, Callable from warnings import warn import pytest @@ -60,8 +61,9 @@ def test_converter_multiple_tensors(test_device: torch.device) -> None: def test_converter_no_parent_device_or_dtype() -> None: + identity: Callable[[Any], Any] = lambda x: x chain = fl.Chain( - fl.Lambda(func=(lambda x: x)), + fl.Lambda(func=identity), fl.Converter(set_device=True, set_dtype=False), ) diff --git a/tests/fluxion/test_utils.py b/tests/fluxion/test_utils.py index 8811c47..8837550 100644 --- a/tests/fluxion/test_utils.py +++ b/tests/fluxion/test_utils.py @@ -7,7 +7,7 @@ from PIL import Image from torch import device as Device, dtype as DType from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore -from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image +from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image @dataclass @@ -62,3 +62,18 @@ def test_tensor_to_image() -> None: assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB" assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L" assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA" + + +def test_no_grad() -> None: + x = torch.randn(1, 1, requires_grad=True) + + with torch.no_grad(): + y = x + 1 + assert not y.requires_grad + + with no_grad(): + z = x + 1 + assert not z.requires_grad + + w = x + 1 + assert w.requires_grad diff --git a/tests/foundationals/clip/test_concepts.py b/tests/foundationals/clip/test_concepts.py index ed86561..9c4ed74 100644 --- a/tests/foundationals/clip/test_concepts.py +++ b/tests/foundationals/clip/test_concepts.py @@ -7,7 +7,7 @@ import transformers # type: ignore from diffusers import StableDiffusionPipeline # type: ignore import refiners.fluxion.layers as fl -from refiners.fluxion.utils import load_from_safetensors +from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.clip.tokenizer import CLIPTokenizer @@ -124,7 +124,7 @@ def test_encoder( our_tokens = tokenizer(prompt) assert torch.equal(our_tokens, ref_tokens) - with torch.no_grad(): + with no_grad(): ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0] our_embeddings = our_encoder_with_new_concepts(prompt) diff --git a/tests/foundationals/clip/test_image_encoder.py b/tests/foundationals/clip/test_image_encoder.py index 3aac668..ff990bd 100644 --- a/tests/foundationals/clip/test_image_encoder.py +++ b/tests/foundationals/clip/test_image_encoder.py @@ -5,7 +5,7 @@ import pytest import torch from transformers import CLIPVisionModelWithProjection # type: ignore -from refiners.fluxion.utils import load_from_safetensors +from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH @@ -44,7 +44,7 @@ def test_encoder( ): x = torch.randn(1, 3, 224, 224).to(test_device) - with torch.no_grad(): + with no_grad(): ref_embeddings = ref_encoder(x).image_embeds our_embeddings = our_encoder(x) diff --git a/tests/foundationals/clip/test_text_encoder.py b/tests/foundationals/clip/test_text_encoder.py index 0e108b7..f1b6f07 100644 --- a/tests/foundationals/clip/test_text_encoder.py +++ b/tests/foundationals/clip/test_text_encoder.py @@ -5,7 +5,7 @@ import pytest import torch import transformers # type: ignore -from refiners.fluxion.utils import load_from_safetensors +from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL from refiners.foundationals.clip.tokenizer import CLIPTokenizer @@ -89,7 +89,7 @@ def test_encoder( our_tokens = tokenizer(prompt) assert torch.equal(our_tokens, ref_tokens) - with torch.no_grad(): + with no_grad(): ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0] our_embeddings = our_encoder(prompt) diff --git a/tests/foundationals/dinov2/test_dinov2.py b/tests/foundationals/dinov2/test_dinov2.py index 9b5f40e..7bcf818 100644 --- a/tests/foundationals/dinov2/test_dinov2.py +++ b/tests/foundationals/dinov2/test_dinov2.py @@ -7,7 +7,7 @@ import torch from transformers import AutoModel # type: ignore from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore -from refiners.fluxion.utils import load_from_safetensors, manual_seed +from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad from refiners.foundationals.dinov2 import ( DINOv2_base, DINOv2_base_reg, @@ -124,7 +124,7 @@ def test_encoder( x = torch.randn(1, 3, 518, 518).to(test_device) - with torch.no_grad(): + with no_grad(): ref_features = ref_backbone(x).last_hidden_state our_features = our_backbone(x) diff --git a/tests/foundationals/latent_diffusion/test_auto_encoder.py b/tests/foundationals/latent_diffusion/test_auto_encoder.py index 2ddca24..462c407 100644 --- a/tests/foundationals/latent_diffusion/test_auto_encoder.py +++ b/tests/foundationals/latent_diffusion/test_auto_encoder.py @@ -6,7 +6,7 @@ import torch from PIL import Image from tests.utils import ensure_similar_images -from refiners.fluxion.utils import load_from_safetensors +from refiners.fluxion.utils import load_from_safetensors, no_grad from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder @@ -38,7 +38,7 @@ def sample_image(ref_path: Path) -> Image.Image: return img -@torch.no_grad() +@no_grad() def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image): encoded = encoder.encode_image(sample_image) decoded = encoder.decode_latents(encoded) diff --git a/tests/foundationals/latent_diffusion/test_controlnet.py b/tests/foundationals/latent_diffusion/test_controlnet.py index 36f3b04..4bfc5e6 100644 --- a/tests/foundationals/latent_diffusion/test_controlnet.py +++ b/tests/foundationals/latent_diffusion/test_controlnet.py @@ -1,10 +1,10 @@ from typing import Iterator import pytest -import torch import refiners.fluxion.layers as fl from refiners.fluxion.adapters.adapter import lookup_top_adapter +from refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet @@ -18,7 +18,7 @@ def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]: yield unet -@torch.no_grad() +@no_grad() def test_single_controlnet(unet: SD1UNet) -> None: original_parent = unet.parent cn = SD1ControlnetAdapter(unet, name="cn") @@ -43,7 +43,7 @@ def test_single_controlnet(unet: SD1UNet) -> None: assert len(list(unet.walk(Controlnet))) == 0 -@torch.no_grad() +@no_grad() def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None: original_parent = unet.parent cn1 = SD1ControlnetAdapter(unet, name="cn1").inject() @@ -71,7 +71,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None: assert len(list(unet.walk(Controlnet))) == 0 -@torch.no_grad() +@no_grad() def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None: original_parent = unet.parent cn1 = SD1ControlnetAdapter(unet, name="cn1").inject() @@ -86,7 +86,7 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None: assert len(list(unet.walk(Controlnet))) == 0 -@torch.no_grad() +@no_grad() def test_two_controlnets_same_name(unet: SD1UNet) -> None: SD1ControlnetAdapter(unet, name="cnx").inject() cn2 = SD1ControlnetAdapter(unet, name="cnx") diff --git a/tests/foundationals/latent_diffusion/test_freeu.py b/tests/foundationals/latent_diffusion/test_freeu.py index 6b7001b..3e4553e 100644 --- a/tests/foundationals/latent_diffusion/test_freeu.py +++ b/tests/foundationals/latent_diffusion/test_freeu.py @@ -4,6 +4,7 @@ import pytest import torch from refiners.fluxion import manual_seed +from refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter @@ -52,14 +53,14 @@ def test_freeu_identity_scales() -> None: unet = SD1UNet(in_channels=4) unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s - with torch.no_grad(): + with no_grad(): unet.set_timestep(timestep=timestep) y_1 = unet(x.clone()) freeu = SDFreeUAdapter(unet, backbone_scales=[1.0, 1.0], skip_scales=[1.0, 1.0]) freeu.inject() - with torch.no_grad(): + with no_grad(): unet.set_timestep(timestep=timestep) y_2 = unet(x.clone()) diff --git a/tests/foundationals/latent_diffusion/test_reference_only_control.py b/tests/foundationals/latent_diffusion/test_reference_only_control.py index 68833b3..d0ed8a3 100644 --- a/tests/foundationals/latent_diffusion/test_reference_only_control.py +++ b/tests/foundationals/latent_diffusion/test_reference_only_control.py @@ -1,6 +1,6 @@ import pytest -import torch +from refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1UNet from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock from refiners.foundationals.latent_diffusion.reference_only_control import ( @@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.reference_only_control import ( ) -@torch.no_grad() +@no_grad() def test_refonly_inject_eject() -> None: unet = SD1UNet(in_channels=9) adapter = ReferenceOnlyControlAdapter(unet) diff --git a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py index cb51253..9435b89 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_double_encoder.py @@ -7,7 +7,7 @@ import torch from torch import Tensor import refiners.fluxion.layers as fl -from refiners.fluxion.utils import manual_seed +from refiners.fluxion.utils import manual_seed, no_grad from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder @@ -65,7 +65,7 @@ def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder: return double_text_encoder -@torch.no_grad() +@no_grad() def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None: manual_seed(seed=0) prompt = "A photo of a pizza." diff --git a/tests/foundationals/latent_diffusion/test_sdxl_unet.py b/tests/foundationals/latent_diffusion/test_sdxl_unet.py index 95b031b..c3d0f10 100644 --- a/tests/foundationals/latent_diffusion/test_sdxl_unet.py +++ b/tests/foundationals/latent_diffusion/test_sdxl_unet.py @@ -6,7 +6,7 @@ import pytest import torch from refiners.fluxion.model_converter import ConversionStage, ModelConverter -from refiners.fluxion.utils import manual_seed +from refiners.fluxion.utils import manual_seed, no_grad from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet @@ -37,7 +37,7 @@ def refiners_sdxl_unet() -> SDXLUNet: return unet -@torch.no_grad() +@no_grad() def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None: source = diffusers_sdxl_unet target = refiners_sdxl_unet diff --git a/tests/foundationals/latent_diffusion/test_unet.py b/tests/foundationals/latent_diffusion/test_unet.py index 4fe09e3..210eca7 100644 --- a/tests/foundationals/latent_diffusion/test_unet.py +++ b/tests/foundationals/latent_diffusion/test_unet.py @@ -1,6 +1,7 @@ import torch from refiners.fluxion import manual_seed +from refiners.fluxion.utils import no_grad from refiners.foundationals.latent_diffusion import SD1UNet @@ -13,11 +14,11 @@ def test_unet_context_flush(): unet = SD1UNet(in_channels=4) unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s - with torch.no_grad(): + with no_grad(): unet.set_timestep(timestep=timestep) y_1 = unet(x.clone()) - with torch.no_grad(): + with no_grad(): unet.set_timestep(timestep=timestep) y_2 = unet(x.clone()) diff --git a/tests/foundationals/segment_anything/test_sam.py b/tests/foundationals/segment_anything/test_sam.py index 0c5fbf9..de4147e 100644 --- a/tests/foundationals/segment_anything/test_sam.py +++ b/tests/foundationals/segment_anything/test_sam.py @@ -18,7 +18,7 @@ from torch import Tensor from refiners.fluxion import manual_seed from refiners.fluxion.model_converter import ModelConverter -from refiners.fluxion.utils import image_to_tensor +from refiners.fluxion.utils import image_to_tensor, no_grad from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention from refiners.foundationals.segment_anything.model import SegmentAnythingH from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer @@ -98,7 +98,7 @@ def truck(ref_path: Path) -> Image.Image: return Image.open(ref_path / "truck.jpg").convert("RGB") -@torch.no_grad() +@no_grad() def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None: manual_seed(seed=0) x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device) @@ -124,7 +124,7 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None: assert torch.equal(input=y_1, other=y_2) -@torch.no_grad() +@no_grad() def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None: image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device) y_1 = facebook_sam_h.image_encoder(image_tensor) @@ -133,7 +133,7 @@ def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, tru assert torch.allclose(input=y_1, other=y_2, atol=1e-4) -@torch.no_grad() +@no_grad() def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None: facebook_prompt_encoder = facebook_sam_h.prompt_encoder refiners_prompt_encoder = sam_h.point_encoder @@ -144,7 +144,7 @@ def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe) -@torch.no_grad() +@no_grad() def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None: facebook_prompt_encoder = facebook_sam_h.prompt_encoder refiners_prompt_encoder = sam_h.mask_encoder @@ -155,7 +155,7 @@ def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe) -@torch.no_grad() +@no_grad() def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None: facebook_prompt_encoder = facebook_sam_h.prompt_encoder refiners_prompt_encoder = sam_h.point_encoder @@ -174,7 +174,7 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe) -@torch.no_grad() +@no_grad() def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None: dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device) dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device) @@ -223,7 +223,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None: assert torch.equal(input=y_1, other=y_2) -@torch.no_grad() +@no_grad() def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None: manual_seed(seed=0) facebook_mask_decoder = facebook_sam_h.mask_decoder