upgrade pyright to 1.1.342; improve no_grad typing

limiteinductive 2023-12-29 10:59:51 +01:00 committed by Benjamin Trom
parent 7b14b4d981
commit 20c229903f
30 changed files with 136 additions and 95 deletions

View file

@ -92,7 +92,7 @@ from PIL import Image
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
from refiners.foundationals.latent_diffusion import SDXLIPAdapter, SDXLT2IAdapter
from refiners.fluxion.utils import manual_seed, image_to_tensor, load_from_safetensors
from refiners.fluxion.utils import manual_seed, no_grad, image_to_tensor, load_from_safetensors
# Load inputs
init_image = Image.open("dropy_logo.png")
@ -122,7 +122,7 @@ t2i_adapter.set_scale(0.8)
sdxl.set_num_inference_steps(50)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
with torch.no_grad():
with no_grad():
# Note: default text prompts for IP-Adapter
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"

View file

@ -54,7 +54,7 @@ build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"pyright == 1.1.333",
"pyright == 1.1.342",
"ruff>=0.0.292",
"docformatter>=1.7.5",
"pytest>=7.4.2",

View file

@ -7,7 +7,7 @@ from diffusers import ControlNetModel # type: ignore
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import (
DPMSolver,
SD1ControlnetAdapter,
@ -20,7 +20,7 @@ class Args(argparse.Namespace):
output_path: str | None
@torch.no_grad()
@no_grad()
def convert(args: Args) -> dict[str, torch.Tensor]:
# low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`
controlnet_src: nn.Module = ControlNetModel.from_pretrained( # type: ignore

View file

@ -11,7 +11,7 @@ from torch.nn.init import zeros_
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import save_to_safetensors
from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
@ -37,7 +37,7 @@ class Args(argparse.Namespace):
verbose: bool
@torch.no_grad()
@no_grad()
def process(args: Args) -> None:
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
# low_cpu_mem_usage=False stops some annoying console messages urging us to `pip install accelerate`

View file

@ -1,3 +1,5 @@
from typing import Callable
from torch import Size, Tensor, device as Device, dtype as DType
from torch.nn.functional import pad
@ -40,7 +42,8 @@ class Downsample(Chain):
),
)
if padding == 0:
self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1))))
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
self.insert(0, Lambda(zero_pad))
if register_shape:
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
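This hunk, and several like it below (the CLIP image encoder, FreeU, the reference-only control adapter, and a converter test), applies the same typing workaround: instead of passing a bare lambda straight to Lambda(...), the lambda is first bound to a variable with an explicit Callable annotation, presumably so pyright 1.1.342 sees a concrete (Tensor) -> Tensor signature rather than a partially unknown lambda type. A minimal standalone sketch of the pattern, with illustrative names only:

from typing import Callable

from torch import Tensor
from torch.nn.functional import pad

# Explicitly annotated lambda: pyright sees a concrete (Tensor) -> Tensor signature.
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))

# A small named function works equally well and is self-documenting.
def zero_pad_fn(x: Tensor) -> Tensor:
    return pad(x, (0, 1, 0, 1))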

View file

@ -7,7 +7,7 @@ import torch
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle
from refiners.fluxion.utils import norm, save_to_safetensors
from refiners.fluxion.utils import no_grad, norm, save_to_safetensors
TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
nn.Conv1d,
@ -512,7 +512,7 @@ class ModelConverter:
return True
@torch.no_grad()
@no_grad()
def _trace_module_execution_order(
self,
module: nn.Module,
@ -603,7 +603,7 @@ class ModelConverter:
return converted_state_dict
@torch.no_grad()
@no_grad()
def _collect_layers_outputs(
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
) -> list[tuple[str, Tensor]]:

View file

@ -1,5 +1,5 @@
from pathlib import Path
from typing import Iterable, Literal, TypeVar
from typing import Any, Iterable, Literal, TypeVar
import torch
from jaxtyping import Float
@ -7,7 +7,14 @@ from numpy import array, float32
from PIL import Image
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm # type: ignore
from torch import (
Tensor,
device as Device,
dtype as DType,
manual_seed as _manual_seed, # type: ignore
no_grad as _no_grad, # type: ignore
norm as _norm, # type: ignore
)
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
T = TypeVar("T")
@ -22,6 +29,11 @@ def manual_seed(seed: int) -> None:
_manual_seed(seed)
class no_grad(_no_grad):
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
return object.__new__(cls)
def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor:
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore
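The new no_grad helper above subclasses torch.no_grad, which at runtime already works both as a context manager and as a decorator; the typed __new__ override keeps the orig_func parameter torch itself uses to support the bare @no_grad form, while giving it a signature pyright 1.1.342 can check, so call sites need no per-line # type: ignore. A minimal usage sketch, assuming refiners is installed (the function and tensor names are illustrative):

import torch
from torch import Tensor

from refiners.fluxion.utils import no_grad

@no_grad()  # decorator form, replacing @torch.no_grad() in the scripts and tests below
def double(x: Tensor) -> Tensor:
    return x * 2

x = torch.ones(2, requires_grad=True)
assert not double(x).requires_grad

with no_grad():  # context-manager form, replacing `with torch.no_grad():`
    y = x + 1
assert not y.requires_grad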

View file

@ -1,4 +1,6 @@
from torch import device as Device, dtype as DType
from typing import Callable
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.common import FeedForward, PositionalEncoder
@ -126,6 +128,7 @@ class CLIPImageEncoder(fl.Chain):
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
super().__init__(
ViTEmbeddings(
image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
@ -142,7 +145,7 @@ class CLIPImageEncoder(fl.Chain):
)
for _ in range(num_layers)
),
fl.Lambda(func=lambda x: x[:, 0, :]),
fl.Lambda(func=cls_token_pooling),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
)

View file

@ -1,5 +1,5 @@
import math
from typing import Any, Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar
import torch
from torch import Tensor
@ -54,9 +54,10 @@ class FreeUBackboneFeatures(fl.Module):
class FreeUSkipFeatures(fl.Chain):
def __init__(self, n: int, skip_scale: float) -> None:
apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
super().__init__(
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)),
fl.Lambda(apply_filter),
)

View file

@ -122,7 +122,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)
sub_targets = {}
sub_targets: dict[str, list[LoraTarget]] = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue
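Here, and in the MonitorTimestepLoss callback further down, the commit annotates empty container literals (dict[str, list[LoraTarget]], dict[str, WandbLoggable]), presumably because pyright 1.1.342 in strict mode treats a bare {} as having a partially unknown type. A minimal standalone sketch with hypothetical names:

# Unannotated: pyright infers `dict[Unknown, Unknown]` for the empty literal
# and flags later uses under strict settings.
# targets_by_model = {}

# Annotated: the key and value types are pinned up front.
targets_by_model: dict[str, list[str]] = {}
targets_by_model.setdefault("unet", []).append("cross_attention")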

View file

@ -1,3 +1,5 @@
from typing import Callable
from torch import Tensor
from refiners.fluxion.adapters.adapter import Adapter
@ -45,8 +47,9 @@ class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
)
with self.setup_adapter(target):
slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1]
super().__init__(
Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)),
Parallel(sa_guided, Chain(Lambda(slice_tensor), target)),
Lambda(self.compute_averaged_unconditioned_x),
)

View file

@ -7,7 +7,7 @@ from PIL import Image
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad
from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@ -39,7 +39,7 @@ class SegmentAnything(fl.Module):
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
@torch.no_grad()
@no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_size = (image.height, image.width)
target_size = self.compute_target_size(original_size)
@ -48,7 +48,7 @@ class SegmentAnything(fl.Module):
original_image_size=original_size,
)
@torch.no_grad()
@no_grad()
def predict(
self,
input: Image.Image | ImageEmbedding,

View file

@ -250,7 +250,7 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
self.timestep_bins[bin_index].append(loss_value)
def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
log_data = {}
log_data: dict[str, WandbLoggable] = {}
for bin_index, losses in self.timestep_bins.items():
if losses:
avg_loss = sum(losses) / len(losses)

View file

@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, Iterable, TypeVar, cast
import numpy as np
from loguru import logger
from torch import Tensor, cuda, device as Device, get_rng_state, no_grad, set_rng_state, stack
from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
from torch.autograd import backward
from torch.nn import Parameter
from torch.optim import Optimizer
@ -26,7 +26,7 @@ from torch.optim.lr_scheduler import (
from torch.utils.data import DataLoader, Dataset
from refiners.fluxion import layers as fl
from refiners.fluxion.utils import manual_seed
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.training_utils.callback import (
Callback,
ClockCallback,

View file

@ -6,7 +6,7 @@ import pytest
import torch
from PIL import Image
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion import (
SD1ControlnetAdapter,
@ -501,7 +501,7 @@ def sdxl_ddim(
return sdxl
@torch.no_grad()
@no_grad()
def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
@ -529,7 +529,7 @@ def test_diffusion_std_random_init(
ensure_similar_images(predicted_image, expected_image_std_random_init)
@torch.no_grad()
@no_grad()
def test_diffusion_karras_random_init(
sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
):
@ -554,7 +554,7 @@ def test_diffusion_karras_random_init(
ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_std_random_init_float16(
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):
@ -583,7 +583,7 @@ def test_diffusion_std_random_init_float16(
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_std_random_init_sag(
sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
):
@ -612,7 +612,7 @@ def test_diffusion_std_random_init_sag(
ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
@torch.no_grad()
@no_grad()
def test_diffusion_std_init_image(
sd15_std: StableDiffusion_1,
cutecat_init: Image.Image,
@ -643,7 +643,7 @@ def test_diffusion_std_init_image(
ensure_similar_images(predicted_image, expected_image_std_init_image)
@torch.no_grad()
@no_grad()
def test_rectangular_init_latents(
sd15_std: StableDiffusion_1,
cutecat_init: Image.Image,
@ -658,7 +658,7 @@ def test_rectangular_init_latents(
assert sd15.lda.decode_latents(x).size == (width, height)
@torch.no_grad()
@no_grad()
def test_diffusion_inpainting(
sd15_inpainting: StableDiffusion_1_Inpainting,
kitchen_dog: Image.Image,
@ -692,7 +692,7 @@ def test_diffusion_inpainting(
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
@torch.no_grad()
@no_grad()
def test_diffusion_inpainting_float16(
sd15_inpainting_float16: StableDiffusion_1_Inpainting,
kitchen_dog: Image.Image,
@ -727,7 +727,7 @@ def test_diffusion_inpainting_float16(
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
@torch.no_grad()
@no_grad()
def test_diffusion_controlnet(
sd15_std: StableDiffusion_1,
controlnet_data: tuple[str, Image.Image, Image.Image, Path],
@ -770,7 +770,7 @@ def test_diffusion_controlnet(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_controlnet_structural_copy(
sd15_std: StableDiffusion_1,
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
@ -814,7 +814,7 @@ def test_diffusion_controlnet_structural_copy(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_controlnet_float16(
sd15_std_float16: StableDiffusion_1,
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
@ -857,7 +857,7 @@ def test_diffusion_controlnet_float16(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_controlnet_stack(
sd15_std: StableDiffusion_1,
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
@ -912,7 +912,7 @@ def test_diffusion_controlnet_stack(
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_lora(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
@ -949,7 +949,7 @@ def test_diffusion_lora(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_lora_float16(
sd15_std_float16: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
@ -986,7 +986,7 @@ def test_diffusion_lora_float16(
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_lora_twice(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],
@ -1025,7 +1025,7 @@ def test_diffusion_lora_twice(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_refonly(
sd15_ddim: StableDiffusion_1,
condition_image_refonly: Image.Image,
@ -1061,7 +1061,7 @@ def test_diffusion_refonly(
ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99)
@torch.no_grad()
@no_grad()
def test_diffusion_inpainting_refonly(
sd15_inpainting: StableDiffusion_1_Inpainting,
scene_image_inpainting_refonly: Image.Image,
@ -1106,7 +1106,7 @@ def test_diffusion_inpainting_refonly(
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
@torch.no_grad()
@no_grad()
def test_diffusion_textual_inversion_random_init(
sd15_std: StableDiffusion_1,
expected_image_textual_inversion_random_init: Image.Image,
@ -1141,7 +1141,7 @@ def test_diffusion_textual_inversion_random_init(
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_ip_adapter(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_weights: Path,
@ -1196,7 +1196,7 @@ def test_diffusion_ip_adapter(
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
@torch.no_grad()
@no_grad()
def test_diffusion_sdxl_ip_adapter(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_weights: Path,
@ -1215,7 +1215,7 @@ def test_diffusion_sdxl_ip_adapter(
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
with torch.no_grad():
with no_grad():
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)
@ -1236,7 +1236,7 @@ def test_diffusion_sdxl_ip_adapter(
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
with torch.no_grad():
with no_grad():
for step in sdxl.steps:
x = sdxl(
x,
@ -1254,7 +1254,7 @@ def test_diffusion_sdxl_ip_adapter(
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
@torch.no_grad()
@no_grad()
def test_diffusion_ip_adapter_controlnet(
sd15_ddim: StableDiffusion_1,
ip_adapter_weights: Path,
@ -1320,7 +1320,7 @@ def test_diffusion_ip_adapter_controlnet(
ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
@torch.no_grad()
@no_grad()
def test_diffusion_ip_adapter_plus(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_plus_weights: Path,
@ -1371,7 +1371,7 @@ def test_diffusion_ip_adapter_plus(
ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_diffusion_sdxl_ip_adapter_plus(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_plus_weights: Path,
@ -1427,7 +1427,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
@torch.no_grad()
@no_grad()
def test_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
) -> None:
@ -1462,7 +1462,7 @@ def test_sdxl_random_init(
ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_sdxl_random_init_sag(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
) -> None:
@ -1498,7 +1498,7 @@ def test_sdxl_random_init_sag(
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
@torch.no_grad()
@no_grad()
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
manual_seed(seed=2)
sd = sd15_ddim
@ -1529,7 +1529,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_t2i_adapter_depth(
sd15_std: StableDiffusion_1,
t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path],
@ -1570,7 +1570,7 @@ def test_t2i_adapter_depth(
ensure_similar_images(predicted_image, expected_image)
@torch.no_grad()
@no_grad()
def test_t2i_adapter_xl_canny(
sdxl_ddim: StableDiffusion_XL,
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
@ -1619,7 +1619,7 @@ def test_t2i_adapter_xl_canny(
ensure_similar_images(predicted_image, expected_image)
@torch.no_grad()
@no_grad()
def test_restart(
sd15_ddim: StableDiffusion_1,
expected_restart: Image.Image,
@ -1659,7 +1659,7 @@ def test_restart(
ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
@torch.no_grad()
@no_grad()
def test_freeu(
sd15_std: StableDiffusion_1,
expected_freeu: Image.Image,

View file

@ -5,7 +5,7 @@ import pytest
import torch
from PIL import Image
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
from tests.utils import ensure_similar_images
@ -41,7 +41,7 @@ def informative_drawings_model(informative_drawings_weights: Path, test_device:
return model
@torch.no_grad()
@no_grad()
def test_preprocessor_informative_drawing(
informative_drawings_model: InformativeDrawings,
cutecat_init: Image.Image,

View file

@ -1,3 +1,4 @@
from typing import Any, Callable
from warnings import warn
import pytest
@ -60,8 +61,9 @@ def test_converter_multiple_tensors(test_device: torch.device) -> None:
def test_converter_no_parent_device_or_dtype() -> None:
identity: Callable[[Any], Any] = lambda x: x
chain = fl.Chain(
fl.Lambda(func=(lambda x: x)),
fl.Lambda(func=identity),
fl.Converter(set_device=True, set_dtype=False),
)

View file

@ -7,7 +7,7 @@ from PIL import Image
from torch import device as Device, dtype as DType
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image
@dataclass
@ -62,3 +62,18 @@ def test_tensor_to_image() -> None:
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
def test_no_grad() -> None:
x = torch.randn(1, 1, requires_grad=True)
with torch.no_grad():
y = x + 1
assert not y.requires_grad
with no_grad():
z = x + 1
assert not z.requires_grad
w = x + 1
assert w.requires_grad

View file

@ -7,7 +7,7 @@ import transformers # type: ignore
from diffusers import StableDiffusionPipeline # type: ignore
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import load_from_safetensors
from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@ -124,7 +124,7 @@ def test_encoder(
our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens)
with torch.no_grad():
with no_grad():
ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0]
our_embeddings = our_encoder_with_new_concepts(prompt)

View file

@ -5,7 +5,7 @@ import pytest
import torch
from transformers import CLIPVisionModelWithProjection # type: ignore
from refiners.fluxion.utils import load_from_safetensors
from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
@ -44,7 +44,7 @@ def test_encoder(
):
x = torch.randn(1, 3, 224, 224).to(test_device)
with torch.no_grad():
with no_grad():
ref_embeddings = ref_encoder(x).image_embeds
our_embeddings = our_encoder(x)

View file

@ -5,7 +5,7 @@ import pytest
import torch
import transformers # type: ignore
from refiners.fluxion.utils import load_from_safetensors
from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@ -89,7 +89,7 @@ def test_encoder(
our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens)
with torch.no_grad():
with no_grad():
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
our_embeddings = our_encoder(prompt)

View file

@ -7,7 +7,7 @@ import torch
from transformers import AutoModel # type: ignore
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
from refiners.fluxion.utils import load_from_safetensors, manual_seed
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.dinov2 import (
DINOv2_base,
DINOv2_base_reg,
@ -124,7 +124,7 @@ def test_encoder(
x = torch.randn(1, 3, 518, 518).to(test_device)
with torch.no_grad():
with no_grad():
ref_features = ref_backbone(x).last_hidden_state
our_features = our_backbone(x)

View file

@ -6,7 +6,7 @@ import torch
from PIL import Image
from tests.utils import ensure_similar_images
from refiners.fluxion.utils import load_from_safetensors
from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
@ -38,7 +38,7 @@ def sample_image(ref_path: Path) -> Image.Image:
return img
@torch.no_grad()
@no_grad()
def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = encoder.encode_image(sample_image)
decoded = encoder.decode_latents(encoded)

View file

@ -1,10 +1,10 @@
from typing import Iterator
import pytest
import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import lookup_top_adapter
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
@ -18,7 +18,7 @@ def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
yield unet
@torch.no_grad()
@no_grad()
def test_single_controlnet(unet: SD1UNet) -> None:
original_parent = unet.parent
cn = SD1ControlnetAdapter(unet, name="cn")
@ -43,7 +43,7 @@ def test_single_controlnet(unet: SD1UNet) -> None:
assert len(list(unet.walk(Controlnet))) == 0
@torch.no_grad()
@no_grad()
def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
original_parent = unet.parent
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
@ -71,7 +71,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
assert len(list(unet.walk(Controlnet))) == 0
@torch.no_grad()
@no_grad()
def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
original_parent = unet.parent
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
@ -86,7 +86,7 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
assert len(list(unet.walk(Controlnet))) == 0
@torch.no_grad()
@no_grad()
def test_two_controlnets_same_name(unet: SD1UNet) -> None:
SD1ControlnetAdapter(unet, name="cnx").inject()
cn2 = SD1ControlnetAdapter(unet, name="cnx")

View file

@ -4,6 +4,7 @@ import pytest
import torch
from refiners.fluxion import manual_seed
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
@ -52,14 +53,14 @@ def test_freeu_identity_scales() -> None:
unet = SD1UNet(in_channels=4)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
with torch.no_grad():
with no_grad():
unet.set_timestep(timestep=timestep)
y_1 = unet(x.clone())
freeu = SDFreeUAdapter(unet, backbone_scales=[1.0, 1.0], skip_scales=[1.0, 1.0])
freeu.inject()
with torch.no_grad():
with no_grad():
unet.set_timestep(timestep=timestep)
y_2 = unet(x.clone())

View file

@ -1,6 +1,6 @@
import pytest
import torch
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from refiners.foundationals.latent_diffusion.reference_only_control import (
@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.reference_only_control import (
)
@torch.no_grad()
@no_grad()
def test_refonly_inject_eject() -> None:
unet = SD1UNet(in_channels=9)
adapter = ReferenceOnlyControlAdapter(unet)

View file

@ -7,7 +7,7 @@ import torch
from torch import Tensor
import refiners.fluxion.layers as fl
from refiners.fluxion.utils import manual_seed
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
@ -65,7 +65,7 @@ def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
return double_text_encoder
@torch.no_grad()
@no_grad()
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt = "A photo of a pizza."

View file

@ -6,7 +6,7 @@ import pytest
import torch
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
from refiners.fluxion.utils import manual_seed
from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
@ -37,7 +37,7 @@ def refiners_sdxl_unet() -> SDXLUNet:
return unet
@torch.no_grad()
@no_grad()
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
source = diffusers_sdxl_unet
target = refiners_sdxl_unet

View file

@ -1,6 +1,7 @@
import torch
from refiners.fluxion import manual_seed
from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet
@ -13,11 +14,11 @@ def test_unet_context_flush():
unet = SD1UNet(in_channels=4)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
with torch.no_grad():
with no_grad():
unet.set_timestep(timestep=timestep)
y_1 = unet(x.clone())
with torch.no_grad():
with no_grad():
unet.set_timestep(timestep=timestep)
y_2 = unet(x.clone())

View file

@ -18,7 +18,7 @@ from torch import Tensor
from refiners.fluxion import manual_seed
from refiners.fluxion.model_converter import ModelConverter
from refiners.fluxion.utils import image_to_tensor
from refiners.fluxion.utils import image_to_tensor, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
@ -98,7 +98,7 @@ def truck(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "truck.jpg").convert("RGB")
@torch.no_grad()
@no_grad()
def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
manual_seed(seed=0)
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
@ -124,7 +124,7 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
assert torch.equal(input=y_1, other=y_2)
@torch.no_grad()
@no_grad()
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
y_1 = facebook_sam_h.image_encoder(image_tensor)
@ -133,7 +133,7 @@ def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, tru
assert torch.allclose(input=y_1, other=y_2, atol=1e-4)
@torch.no_grad()
@no_grad()
def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.point_encoder
@ -144,7 +144,7 @@ def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM,
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
@torch.no_grad()
@no_grad()
def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.mask_encoder
@ -155,7 +155,7 @@ def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
@torch.no_grad()
@no_grad()
def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.point_encoder
@ -174,7 +174,7 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro
assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)
@torch.no_grad()
@no_grad()
def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
@ -223,7 +223,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
assert torch.equal(input=y_1, other=y_2)
@torch.no_grad()
@no_grad()
def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
manual_seed(seed=0)
facebook_mask_decoder = facebook_sam_h.mask_decoder