upgrade pyright to 1.1.342 ; improve no_grad typing

limiteinductive 2023-12-29 10:59:51 +01:00 committed by Benjamin Trom
parent 7b14b4d981
commit 20c229903f
30 changed files with 136 additions and 95 deletions

View file

@@ -92,7 +92,7 @@ from PIL import Image
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
from refiners.foundationals.latent_diffusion import SDXLIPAdapter, SDXLT2IAdapter
-from refiners.fluxion.utils import manual_seed, image_to_tensor, load_from_safetensors
+from refiners.fluxion.utils import manual_seed, no_grad, image_to_tensor, load_from_safetensors
# Load inputs
init_image = Image.open("dropy_logo.png")

@@ -122,7 +122,7 @@ t2i_adapter.set_scale(0.8)
sdxl.set_num_inference_steps(50)
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
-with torch.no_grad():
+with no_grad():
# Note: default text prompts for IP-Adapter
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"

View file

@@ -54,7 +54,7 @@ build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
-"pyright == 1.1.333",
+"pyright == 1.1.342",
"ruff>=0.0.292",
"docformatter>=1.7.5",
"pytest>=7.4.2",

View file

@@ -7,7 +7,7 @@ from diffusers import ControlNetModel # type: ignore
from torch import nn
from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import (
DPMSolver,
SD1ControlnetAdapter,

@@ -20,7 +20,7 @@ class Args(argparse.Namespace):
output_path: str | None
-@torch.no_grad()
+@no_grad()
def convert(args: Args) -> dict[str, torch.Tensor]:
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
controlnet_src: nn.Module = ControlNetModel.from_pretrained( # type: ignore

View file

@@ -11,7 +11,7 @@ from torch.nn.init import zeros_
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import no_grad, save_to_safetensors
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets

@@ -37,7 +37,7 @@ class Args(argparse.Namespace):
verbose: bool
-@torch.no_grad()
+@no_grad()
def process(args: Args) -> None:
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`

View file

@@ -1,3 +1,5 @@
+from typing import Callable
+
from torch import Size, Tensor, device as Device, dtype as DType
from torch.nn.functional import pad

@@ -40,7 +42,8 @@ class Downsample(Chain):
),
)
if padding == 0:
-self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1))))
+zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
+self.insert(0, Lambda(zero_pad))
if register_shape:
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
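This is the first of several places in this commit where a bare lambda handed to a Lambda layer is first bound to a name with an explicit Callable annotation; the same pattern recurs below in the CLIP image encoder, FreeU, and the self-attention injection adapter. A minimal standalone sketch of the pattern (illustrative, not copied verbatim from any single file):

from typing import Callable

from torch import Tensor
from torch.nn.functional import pad

import refiners.fluxion.layers as fl

# Binding the lambda to an annotated name gives the type checker a fully
# typed Callable[[Tensor], Tensor] instead of an inferred, partially
# unknown signature.
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
pad_layer = fl.Lambda(func=zero_pad)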

View file

@@ -7,7 +7,7 @@ import torch
from torch import Tensor, nn
from torch.utils.hooks import RemovableHandle
-from refiners.fluxion.utils import norm, save_to_safetensors
+from refiners.fluxion.utils import no_grad, norm, save_to_safetensors
TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
nn.Conv1d,

@@ -512,7 +512,7 @@ class ModelConverter:
return True
-@torch.no_grad()
+@no_grad()
def _trace_module_execution_order(
self,
module: nn.Module,

@@ -603,7 +603,7 @@ class ModelConverter:
return converted_state_dict
-@torch.no_grad()
+@no_grad()
def _collect_layers_outputs(
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
) -> list[tuple[str, Tensor]]:

View file

@@ -1,5 +1,5 @@
from pathlib import Path
-from typing import Iterable, Literal, TypeVar
+from typing import Any, Iterable, Literal, TypeVar
import torch
from jaxtyping import Float

@@ -7,7 +7,14 @@ from numpy import array, float32
from PIL import Image
from safetensors import safe_open as _safe_open # type: ignore
from safetensors.torch import save_file as _save_file # type: ignore
-from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm # type: ignore
+from torch import (
+    Tensor,
+    device as Device,
+    dtype as DType,
+    manual_seed as _manual_seed, # type: ignore
+    no_grad as _no_grad, # type: ignore
+    norm as _norm, # type: ignore
+)
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
T = TypeVar("T")

@@ -22,6 +29,11 @@ def manual_seed(seed: int) -> None:
_manual_seed(seed)
+class no_grad(_no_grad):
+    def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
+        return object.__new__(cls)
def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor:
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore
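The no_grad subclass above keeps torch.no_grad's runtime behaviour while exposing a constructor signature the type checker accepts both as a context manager and as a decorator, which is what lets the rest of this commit swap torch.no_grad for no_grad wholesale. A short usage sketch under that assumption (scale_embedding is an illustrative function, not part of the diff):

import torch

from refiners.fluxion.utils import no_grad

@no_grad()  # decorator form, replacing @torch.no_grad()
def scale_embedding(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

x = torch.randn(2, 3, requires_grad=True)

with no_grad():  # context-manager form, replacing torch.no_grad()
    y = x + 1

assert not y.requires_grad
assert not scale_embedding(x).requires_grad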

View file

@@ -1,4 +1,6 @@
-from torch import device as Device, dtype as DType
+from typing import Callable
+
+from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
from refiners.foundationals.clip.common import FeedForward, PositionalEncoder

@@ -126,6 +128,7 @@ class CLIPImageEncoder(fl.Chain):
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
+cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
super().__init__(
ViTEmbeddings(
image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype

@@ -142,7 +145,7 @@
)
for _ in range(num_layers)
),
-fl.Lambda(func=lambda x: x[:, 0, :]),
+fl.Lambda(func=cls_token_pooling),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
)

View file

@@ -1,5 +1,5 @@
import math
-from typing import Any, Generic, TypeVar
+from typing import Any, Callable, Generic, TypeVar
import torch
from torch import Tensor

@@ -54,9 +54,10 @@ class FreeUBackboneFeatures(fl.Module):
class FreeUSkipFeatures(fl.Chain):
def __init__(self, n: int, skip_scale: float) -> None:
+apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
super().__init__(
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
-fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)),
+fl.Lambda(apply_filter),
)

View file

@@ -122,7 +122,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
tensors = load_from_safetensors(checkpoint_path, device=target.device)
-sub_targets = {}
+sub_targets: dict[str, list[LoraTarget]] = {}
for model_name in MODELS:
if not (v := metadata.get(f"{model_name}_targets", "")):
continue

View file

@@ -1,3 +1,5 @@
+from typing import Callable
+
from torch import Tensor
from refiners.fluxion.adapters.adapter import Adapter

@@ -45,8 +47,9 @@ class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
)
with self.setup_adapter(target):
+slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1]
super().__init__(
-Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)),
+Parallel(sa_guided, Chain(Lambda(slice_tensor), target)),
Lambda(self.compute_averaged_unconditioned_x),
)

View file

@@ -7,7 +7,7 @@ from PIL import Image
from torch import Tensor, device as Device, dtype as DType
import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad
+from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder

@@ -39,7 +39,7 @@ class SegmentAnything(fl.Module):
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
-@torch.no_grad()
+@no_grad()
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
original_size = (image.height, image.width)
target_size = self.compute_target_size(original_size)

@@ -48,7 +48,7 @@ class SegmentAnything(fl.Module):
original_image_size=original_size,
)
-@torch.no_grad()
+@no_grad()
def predict(
self,
input: Image.Image | ImageEmbedding,

View file

@@ -250,7 +250,7 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
self.timestep_bins[bin_index].append(loss_value)
def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
-log_data = {}
+log_data: dict[str, WandbLoggable] = {}
for bin_index, losses in self.timestep_bins.items():
if losses:
avg_loss = sum(losses) / len(losses)
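Together with the sub_targets annotation above, this is the other recurring pyright fix in the commit: empty container literals get an explicit type at the assignment site so the checker does not infer an unknown value type. A generic illustration with placeholder types (the real code uses dict[str, WandbLoggable]):

# Without the annotation, the value type of the empty dict is inferred as
# unknown; annotating the assignment fixes every later access.
log_data: dict[str, float] = {}
log_data["average_loss"] = 0.25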

View file

@@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, Iterable, TypeVar, cast
import numpy as np
from loguru import logger
-from torch import Tensor, cuda, device as Device, get_rng_state, no_grad, set_rng_state, stack
+from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
from torch.autograd import backward
from torch.nn import Parameter
from torch.optim import Optimizer

@@ -26,7 +26,7 @@ from torch.optim.lr_scheduler import (
from torch.utils.data import DataLoader, Dataset
from refiners.fluxion import layers as fl
-from refiners.fluxion.utils import manual_seed
+from refiners.fluxion.utils import manual_seed, no_grad
from refiners.training_utils.callback import (
Callback,
ClockCallback,

View file

@@ -6,7 +6,7 @@ import pytest
import torch
from PIL import Image
-from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed
+from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion import (
SD1ControlnetAdapter,

@@ -501,7 +501,7 @@ def sdxl_ddim(
return sdxl
-@torch.no_grad()
+@no_grad()
def test_diffusion_std_random_init(
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):

@@ -529,7 +529,7 @@ def test_diffusion_std_random_init(
ensure_similar_images(predicted_image, expected_image_std_random_init)
-@torch.no_grad()
+@no_grad()
def test_diffusion_karras_random_init(
sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
):

@@ -554,7 +554,7 @@ def test_diffusion_karras_random_init(
ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_std_random_init_float16(
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
):

@@ -583,7 +583,7 @@ def test_diffusion_std_random_init_float16(
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_std_random_init_sag(
sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
):

@@ -612,7 +612,7 @@ def test_diffusion_std_random_init_sag(
ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
-@torch.no_grad()
+@no_grad()
def test_diffusion_std_init_image(
sd15_std: StableDiffusion_1,
cutecat_init: Image.Image,

@@ -643,7 +643,7 @@ def test_diffusion_std_init_image(
ensure_similar_images(predicted_image, expected_image_std_init_image)
-@torch.no_grad()
+@no_grad()
def test_rectangular_init_latents(
sd15_std: StableDiffusion_1,
cutecat_init: Image.Image,

@@ -658,7 +658,7 @@ def test_rectangular_init_latents(
assert sd15.lda.decode_latents(x).size == (width, height)
-@torch.no_grad()
+@no_grad()
def test_diffusion_inpainting(
sd15_inpainting: StableDiffusion_1_Inpainting,
kitchen_dog: Image.Image,

@@ -692,7 +692,7 @@ def test_diffusion_inpainting(
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
-@torch.no_grad()
+@no_grad()
def test_diffusion_inpainting_float16(
sd15_inpainting_float16: StableDiffusion_1_Inpainting,
kitchen_dog: Image.Image,
@@ -727,7 +727,7 @@ def test_diffusion_inpainting_float16(
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
-@torch.no_grad()
+@no_grad()
def test_diffusion_controlnet(
sd15_std: StableDiffusion_1,
controlnet_data: tuple[str, Image.Image, Image.Image, Path],

@@ -770,7 +770,7 @@ def test_diffusion_controlnet(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_controlnet_structural_copy(
sd15_std: StableDiffusion_1,
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],

@@ -814,7 +814,7 @@ def test_diffusion_controlnet_structural_copy(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_controlnet_float16(
sd15_std_float16: StableDiffusion_1,
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],

@@ -857,7 +857,7 @@ def test_diffusion_controlnet_float16(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_controlnet_stack(
sd15_std: StableDiffusion_1,
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],

@@ -912,7 +912,7 @@ def test_diffusion_controlnet_stack(
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_lora(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],

@@ -949,7 +949,7 @@ def test_diffusion_lora(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_lora_float16(
sd15_std_float16: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],

@@ -986,7 +986,7 @@ def test_diffusion_lora_float16(
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_lora_twice(
sd15_std: StableDiffusion_1,
lora_data_pokemon: tuple[Image.Image, Path],

@@ -1025,7 +1025,7 @@ def test_diffusion_lora_twice(
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_refonly(
sd15_ddim: StableDiffusion_1,
condition_image_refonly: Image.Image,

@@ -1061,7 +1061,7 @@ def test_diffusion_refonly(
ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99)
-@torch.no_grad()
+@no_grad()
def test_diffusion_inpainting_refonly(
sd15_inpainting: StableDiffusion_1_Inpainting,
scene_image_inpainting_refonly: Image.Image,
@@ -1106,7 +1106,7 @@ def test_diffusion_inpainting_refonly(
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
-@torch.no_grad()
+@no_grad()
def test_diffusion_textual_inversion_random_init(
sd15_std: StableDiffusion_1,
expected_image_textual_inversion_random_init: Image.Image,

@@ -1141,7 +1141,7 @@ def test_diffusion_textual_inversion_random_init(
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_ip_adapter(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_weights: Path,

@@ -1196,7 +1196,7 @@ def test_diffusion_ip_adapter(
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
-@torch.no_grad()
+@no_grad()
def test_diffusion_sdxl_ip_adapter(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_weights: Path,

@@ -1215,7 +1215,7 @@ def test_diffusion_sdxl_ip_adapter(
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
ip_adapter.inject()
-with torch.no_grad():
+with no_grad():
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
text=prompt, negative_text=negative_prompt
)

@@ -1236,7 +1236,7 @@ def test_diffusion_sdxl_ip_adapter(
manual_seed(2)
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
-with torch.no_grad():
+with no_grad():
for step in sdxl.steps:
x = sdxl(
x,

@@ -1254,7 +1254,7 @@ def test_diffusion_sdxl_ip_adapter(
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
-@torch.no_grad()
+@no_grad()
def test_diffusion_ip_adapter_controlnet(
sd15_ddim: StableDiffusion_1,
ip_adapter_weights: Path,

@@ -1320,7 +1320,7 @@ def test_diffusion_ip_adapter_controlnet(
ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
-@torch.no_grad()
+@no_grad()
def test_diffusion_ip_adapter_plus(
sd15_ddim_lda_ft_mse: StableDiffusion_1,
ip_adapter_plus_weights: Path,

@@ -1371,7 +1371,7 @@ def test_diffusion_ip_adapter_plus(
ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_diffusion_sdxl_ip_adapter_plus(
sdxl_ddim: StableDiffusion_XL,
sdxl_ip_adapter_plus_weights: Path,
@@ -1427,7 +1427,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
-@torch.no_grad()
+@no_grad()
def test_sdxl_random_init(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
) -> None:

@@ -1462,7 +1462,7 @@ def test_sdxl_random_init(
ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_sdxl_random_init_sag(
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
) -> None:

@@ -1498,7 +1498,7 @@ def test_sdxl_random_init_sag(
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
-@torch.no_grad()
+@no_grad()
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
manual_seed(seed=2)
sd = sd15_ddim

@@ -1529,7 +1529,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_t2i_adapter_depth(
sd15_std: StableDiffusion_1,
t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path],

@@ -1570,7 +1570,7 @@ def test_t2i_adapter_depth(
ensure_similar_images(predicted_image, expected_image)
-@torch.no_grad()
+@no_grad()
def test_t2i_adapter_xl_canny(
sdxl_ddim: StableDiffusion_XL,
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],

@@ -1619,7 +1619,7 @@ def test_t2i_adapter_xl_canny(
ensure_similar_images(predicted_image, expected_image)
-@torch.no_grad()
+@no_grad()
def test_restart(
sd15_ddim: StableDiffusion_1,
expected_restart: Image.Image,

@@ -1659,7 +1659,7 @@ def test_restart(
ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
-@torch.no_grad()
+@no_grad()
def test_freeu(
sd15_std: StableDiffusion_1,
expected_freeu: Image.Image,

View file

@@ -5,7 +5,7 @@ import pytest
import torch
from PIL import Image
-from refiners.fluxion.utils import image_to_tensor, tensor_to_image
+from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
from tests.utils import ensure_similar_images

@@ -41,7 +41,7 @@ def informative_drawings_model(informative_drawings_weights: Path, test_device:
return model
-@torch.no_grad()
+@no_grad()
def test_preprocessor_informative_drawing(
informative_drawings_model: InformativeDrawings,
cutecat_init: Image.Image,

View file

@@ -1,3 +1,4 @@
+from typing import Any, Callable
from warnings import warn
import pytest

@@ -60,8 +61,9 @@ def test_converter_multiple_tensors(test_device: torch.device) -> None:
def test_converter_no_parent_device_or_dtype() -> None:
+identity: Callable[[Any], Any] = lambda x: x
chain = fl.Chain(
-fl.Lambda(func=(lambda x: x)),
+fl.Lambda(func=identity),
fl.Converter(set_device=True, set_dtype=False),
)

View file

@@ -7,7 +7,7 @@ from PIL import Image
from torch import device as Device, dtype as DType
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
-from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image
+from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image
@dataclass

@@ -62,3 +62,18 @@ def test_tensor_to_image() -> None:
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
+
+
+def test_no_grad() -> None:
+    x = torch.randn(1, 1, requires_grad=True)
+
+    with torch.no_grad():
+        y = x + 1
+        assert not y.requires_grad
+
+    with no_grad():
+        z = x + 1
+        assert not z.requires_grad
+
+    w = x + 1
+    assert w.requires_grad

View file

@@ -7,7 +7,7 @@ import transformers # type: ignore
from diffusers import StableDiffusionPipeline # type: ignore
import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import load_from_safetensors
+from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer

@@ -124,7 +124,7 @@ def test_encoder(
our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens)
-with torch.no_grad():
+with no_grad():
ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0]
our_embeddings = our_encoder_with_new_concepts(prompt)

View file

@@ -5,7 +5,7 @@ import pytest
import torch
from transformers import CLIPVisionModelWithProjection # type: ignore
-from refiners.fluxion.utils import load_from_safetensors
+from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH

@@ -44,7 +44,7 @@ def test_encoder(
):
x = torch.randn(1, 3, 224, 224).to(test_device)
-with torch.no_grad():
+with no_grad():
ref_embeddings = ref_encoder(x).image_embeds
our_embeddings = our_encoder(x)

View file

@@ -5,7 +5,7 @@ import pytest
import torch
import transformers # type: ignore
-from refiners.fluxion.utils import load_from_safetensors
+from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.clip.tokenizer import CLIPTokenizer

@@ -89,7 +89,7 @@ def test_encoder(
our_tokens = tokenizer(prompt)
assert torch.equal(our_tokens, ref_tokens)
-with torch.no_grad():
+with no_grad():
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
our_embeddings = our_encoder(prompt)

View file

@@ -7,7 +7,7 @@ import torch
from transformers import AutoModel # type: ignore
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
-from refiners.fluxion.utils import load_from_safetensors, manual_seed
+from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.dinov2 import (
DINOv2_base,
DINOv2_base_reg,

@@ -124,7 +124,7 @@ def test_encoder(
x = torch.randn(1, 3, 518, 518).to(test_device)
-with torch.no_grad():
+with no_grad():
ref_features = ref_backbone(x).last_hidden_state
our_features = our_backbone(x)

View file

@@ -6,7 +6,7 @@ import torch
from PIL import Image
from tests.utils import ensure_similar_images
-from refiners.fluxion.utils import load_from_safetensors
+from refiners.fluxion.utils import load_from_safetensors, no_grad
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder

@@ -38,7 +38,7 @@ def sample_image(ref_path: Path) -> Image.Image:
return img
-@torch.no_grad()
+@no_grad()
def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
encoded = encoder.encode_image(sample_image)
decoded = encoder.decode_latents(encoded)

View file

@@ -1,10 +1,10 @@
from typing import Iterator
import pytest
-import torch
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import lookup_top_adapter
+from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet

@@ -18,7 +18,7 @@ def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
yield unet
-@torch.no_grad()
+@no_grad()
def test_single_controlnet(unet: SD1UNet) -> None:
original_parent = unet.parent
cn = SD1ControlnetAdapter(unet, name="cn")

@@ -43,7 +43,7 @@ def test_single_controlnet(unet: SD1UNet) -> None:
assert len(list(unet.walk(Controlnet))) == 0
-@torch.no_grad()
+@no_grad()
def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
original_parent = unet.parent
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()

@@ -71,7 +71,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
assert len(list(unet.walk(Controlnet))) == 0
-@torch.no_grad()
+@no_grad()
def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
original_parent = unet.parent
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()

@@ -86,7 +86,7 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
assert len(list(unet.walk(Controlnet))) == 0
-@torch.no_grad()
+@no_grad()
def test_two_controlnets_same_name(unet: SD1UNet) -> None:
SD1ControlnetAdapter(unet, name="cnx").inject()
cn2 = SD1ControlnetAdapter(unet, name="cnx")

View file

@@ -4,6 +4,7 @@ import pytest
import torch
from refiners.fluxion import manual_seed
+from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter

@@ -52,14 +53,14 @@ def test_freeu_identity_scales() -> None:
unet = SD1UNet(in_channels=4)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
-with torch.no_grad():
+with no_grad():
unet.set_timestep(timestep=timestep)
y_1 = unet(x.clone())
freeu = SDFreeUAdapter(unet, backbone_scales=[1.0, 1.0], skip_scales=[1.0, 1.0])
freeu.inject()
-with torch.no_grad():
+with no_grad():
unet.set_timestep(timestep=timestep)
y_2 = unet(x.clone())

View file

@@ -1,6 +1,6 @@
import pytest
-import torch
+from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
from refiners.foundationals.latent_diffusion.reference_only_control import (

@@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.reference_only_control import (
)
-@torch.no_grad()
+@no_grad()
def test_refonly_inject_eject() -> None:
unet = SD1UNet(in_channels=9)
adapter = ReferenceOnlyControlAdapter(unet)

View file

@@ -7,7 +7,7 @@ import torch
from torch import Tensor
import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import manual_seed
+from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder

@@ -65,7 +65,7 @@ def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
return double_text_encoder
-@torch.no_grad()
+@no_grad()
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
manual_seed(seed=0)
prompt = "A photo of a pizza."

View file

@@ -6,7 +6,7 @@ import pytest
import torch
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
-from refiners.fluxion.utils import manual_seed
+from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet

@@ -37,7 +37,7 @@ def refiners_sdxl_unet() -> SDXLUNet:
return unet
-@torch.no_grad()
+@no_grad()
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
source = diffusers_sdxl_unet
target = refiners_sdxl_unet

View file

@@ -1,6 +1,7 @@
import torch
from refiners.fluxion import manual_seed
+from refiners.fluxion.utils import no_grad
from refiners.foundationals.latent_diffusion import SD1UNet

@@ -13,11 +14,11 @@ def test_unet_context_flush():
unet = SD1UNet(in_channels=4)
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
-with torch.no_grad():
+with no_grad():
unet.set_timestep(timestep=timestep)
y_1 = unet(x.clone())
-with torch.no_grad():
+with no_grad():
unet.set_timestep(timestep=timestep)
y_2 = unet(x.clone())

View file

@@ -18,7 +18,7 @@ from torch import Tensor
from refiners.fluxion import manual_seed
from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import image_to_tensor
+from refiners.fluxion.utils import image_to_tensor, no_grad
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
from refiners.foundationals.segment_anything.model import SegmentAnythingH
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer

@@ -98,7 +98,7 @@ def truck(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "truck.jpg").convert("RGB")
-@torch.no_grad()
+@no_grad()
def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
manual_seed(seed=0)
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)

@@ -124,7 +124,7 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
assert torch.equal(input=y_1, other=y_2)
-@torch.no_grad()
+@no_grad()
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
y_1 = facebook_sam_h.image_encoder(image_tensor)

@@ -133,7 +133,7 @@ def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, tru
assert torch.allclose(input=y_1, other=y_2, atol=1e-4)
-@torch.no_grad()
+@no_grad()
def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.point_encoder

@@ -144,7 +144,7 @@ def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM,
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
-@torch.no_grad()
+@no_grad()
def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.mask_encoder

@@ -155,7 +155,7 @@ def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
-@torch.no_grad()
+@no_grad()
def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
refiners_prompt_encoder = sam_h.point_encoder

@@ -174,7 +174,7 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro
assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)
-@torch.no_grad()
+@no_grad()
def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)

@@ -223,7 +223,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
assert torch.equal(input=y_1, other=y_2)
-@torch.no_grad()
+@no_grad()
def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
manual_seed(seed=0)
facebook_mask_decoder = facebook_sam_h.mask_decoder