mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
upgrade pyright to 1.1.342 ; improve no_grad typing
This commit is contained in:
parent
7b14b4d981
commit
20c229903f
|
@ -92,7 +92,7 @@ from PIL import Image
|
|||
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
|
||||
from refiners.foundationals.latent_diffusion import SDXLIPAdapter, SDXLT2IAdapter
|
||||
from refiners.fluxion.utils import manual_seed, image_to_tensor, load_from_safetensors
|
||||
from refiners.fluxion.utils import manual_seed, no_grad, image_to_tensor, load_from_safetensors
|
||||
|
||||
# Load inputs
|
||||
init_image = Image.open("dropy_logo.png")
|
||||
|
@ -122,7 +122,7 @@ t2i_adapter.set_scale(0.8)
|
|||
sdxl.set_num_inference_steps(50)
|
||||
sdxl.set_self_attention_guidance(enable=True, scale=0.75)
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
# Note: default text prompts for IP-Adapter
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
|
|
@ -54,7 +54,7 @@ build-backend = "hatchling.build"
|
|||
[tool.rye]
|
||||
managed = true
|
||||
dev-dependencies = [
|
||||
"pyright == 1.1.333",
|
||||
"pyright == 1.1.342",
|
||||
"ruff>=0.0.292",
|
||||
"docformatter>=1.7.5",
|
||||
"pytest>=7.4.2",
|
||||
|
|
|
@ -7,7 +7,7 @@ from diffusers import ControlNetModel # type: ignore
|
|||
from torch import nn
|
||||
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.fluxion.utils import no_grad, save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion import (
|
||||
DPMSolver,
|
||||
SD1ControlnetAdapter,
|
||||
|
@ -20,7 +20,7 @@ class Args(argparse.Namespace):
|
|||
output_path: str | None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def convert(args: Args) -> dict[str, torch.Tensor]:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
controlnet_src: nn.Module = ControlNetModel.from_pretrained( # type: ignore
|
||||
|
|
|
@ -11,7 +11,7 @@ from torch.nn.init import zeros_
|
|||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.lora import Lora, LoraAdapter
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import save_to_safetensors
|
||||
from refiners.fluxion.utils import no_grad, save_to_safetensors
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
|
||||
|
||||
|
@ -37,7 +37,7 @@ class Args(argparse.Namespace):
|
|||
verbose: bool
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def process(args: Args) -> None:
|
||||
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Callable
|
||||
|
||||
from torch import Size, Tensor, device as Device, dtype as DType
|
||||
from torch.nn.functional import pad
|
||||
|
||||
|
@ -40,7 +42,8 @@ class Downsample(Chain):
|
|||
),
|
||||
)
|
||||
if padding == 0:
|
||||
self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1))))
|
||||
zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
|
||||
self.insert(0, Lambda(zero_pad))
|
||||
if register_shape:
|
||||
self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from torch import Tensor, nn
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from refiners.fluxion.utils import norm, save_to_safetensors
|
||||
from refiners.fluxion.utils import no_grad, norm, save_to_safetensors
|
||||
|
||||
TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
|
||||
nn.Conv1d,
|
||||
|
@ -512,7 +512,7 @@ class ModelConverter:
|
|||
|
||||
return True
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def _trace_module_execution_order(
|
||||
self,
|
||||
module: nn.Module,
|
||||
|
@ -603,7 +603,7 @@ class ModelConverter:
|
|||
|
||||
return converted_state_dict
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def _collect_layers_outputs(
|
||||
self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
|
||||
) -> list[tuple[str, Tensor]]:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from pathlib import Path
|
||||
from typing import Iterable, Literal, TypeVar
|
||||
from typing import Any, Iterable, Literal, TypeVar
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float
|
||||
|
@ -7,7 +7,14 @@ from numpy import array, float32
|
|||
from PIL import Image
|
||||
from safetensors import safe_open as _safe_open # type: ignore
|
||||
from safetensors.torch import save_file as _save_file # type: ignore
|
||||
from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm # type: ignore
|
||||
from torch import (
|
||||
Tensor,
|
||||
device as Device,
|
||||
dtype as DType,
|
||||
manual_seed as _manual_seed, # type: ignore
|
||||
no_grad as _no_grad, # type: ignore
|
||||
norm as _norm, # type: ignore
|
||||
)
|
||||
from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad # type: ignore
|
||||
|
||||
T = TypeVar("T")
|
||||
|
@ -22,6 +29,11 @@ def manual_seed(seed: int) -> None:
|
|||
_manual_seed(seed)
|
||||
|
||||
|
||||
class no_grad(_no_grad):
|
||||
def __new__(cls, orig_func: Any | None = None) -> "no_grad": # type: ignore
|
||||
return object.__new__(cls)
|
||||
|
||||
|
||||
def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor:
|
||||
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore
|
||||
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
from torch import device as Device, dtype as DType
|
||||
from typing import Callable
|
||||
|
||||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.foundationals.clip.common import FeedForward, PositionalEncoder
|
||||
|
@ -126,6 +128,7 @@ class CLIPImageEncoder(fl.Chain):
|
|||
self.num_layers = num_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.feedforward_dim = feedforward_dim
|
||||
cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
|
||||
super().__init__(
|
||||
ViTEmbeddings(
|
||||
image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
|
||||
|
@ -142,7 +145,7 @@ class CLIPImageEncoder(fl.Chain):
|
|||
)
|
||||
for _ in range(num_layers)
|
||||
),
|
||||
fl.Lambda(func=lambda x: x[:, 0, :]),
|
||||
fl.Lambda(func=cls_token_pooling),
|
||||
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
|
||||
fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
|
||||
)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import math
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Callable, Generic, TypeVar
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
@ -54,9 +54,10 @@ class FreeUBackboneFeatures(fl.Module):
|
|||
|
||||
class FreeUSkipFeatures(fl.Chain):
|
||||
def __init__(self, n: int, skip_scale: float) -> None:
|
||||
apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
|
||||
super().__init__(
|
||||
fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
|
||||
fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)),
|
||||
fl.Lambda(apply_filter),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -122,7 +122,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
|
|||
assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
|
||||
tensors = load_from_safetensors(checkpoint_path, device=target.device)
|
||||
|
||||
sub_targets = {}
|
||||
sub_targets: dict[str, list[LoraTarget]] = {}
|
||||
for model_name in MODELS:
|
||||
if not (v := metadata.get(f"{model_name}_targets", "")):
|
||||
continue
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Callable
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
|
@ -45,8 +47,9 @@ class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
|
|||
)
|
||||
|
||||
with self.setup_adapter(target):
|
||||
slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1]
|
||||
super().__init__(
|
||||
Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)),
|
||||
Parallel(sa_guided, Chain(Lambda(slice_tensor), target)),
|
||||
Lambda(self.compute_averaged_unconditioned_x),
|
||||
)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
|||
from torch import Tensor, device as Device, dtype as DType
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad
|
||||
from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
|
||||
from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
|
||||
from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
|
||||
from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
|
||||
|
@ -39,7 +39,7 @@ class SegmentAnything(fl.Module):
|
|||
self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
|
||||
self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
|
||||
original_size = (image.height, image.width)
|
||||
target_size = self.compute_target_size(original_size)
|
||||
|
@ -48,7 +48,7 @@ class SegmentAnything(fl.Module):
|
|||
original_image_size=original_size,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def predict(
|
||||
self,
|
||||
input: Image.Image | ImageEmbedding,
|
||||
|
|
|
@ -250,7 +250,7 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
|
|||
self.timestep_bins[bin_index].append(loss_value)
|
||||
|
||||
def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
|
||||
log_data = {}
|
||||
log_data: dict[str, WandbLoggable] = {}
|
||||
for bin_index, losses in self.timestep_bins.items():
|
||||
if losses:
|
||||
avg_loss = sum(losses) / len(losses)
|
||||
|
|
|
@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
|||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from torch import Tensor, cuda, device as Device, get_rng_state, no_grad, set_rng_state, stack
|
||||
from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
|
||||
from torch.autograd import backward
|
||||
from torch.nn import Parameter
|
||||
from torch.optim import Optimizer
|
||||
|
@ -26,7 +26,7 @@ from torch.optim.lr_scheduler import (
|
|||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from refiners.fluxion import layers as fl
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
from refiners.fluxion.utils import manual_seed, no_grad
|
||||
from refiners.training_utils.callback import (
|
||||
Callback,
|
||||
ClockCallback,
|
||||
|
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed
|
||||
from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender
|
||||
from refiners.foundationals.latent_diffusion import (
|
||||
SD1ControlnetAdapter,
|
||||
|
@ -501,7 +501,7 @@ def sdxl_ddim(
|
|||
return sdxl
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_std_random_init(
|
||||
sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||
):
|
||||
|
@ -529,7 +529,7 @@ def test_diffusion_std_random_init(
|
|||
ensure_similar_images(predicted_image, expected_image_std_random_init)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_karras_random_init(
|
||||
sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
|
||||
):
|
||||
|
@ -554,7 +554,7 @@ def test_diffusion_karras_random_init(
|
|||
ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_std_random_init_float16(
|
||||
sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
|
||||
):
|
||||
|
@ -583,7 +583,7 @@ def test_diffusion_std_random_init_float16(
|
|||
ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_std_random_init_sag(
|
||||
sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
|
||||
):
|
||||
|
@ -612,7 +612,7 @@ def test_diffusion_std_random_init_sag(
|
|||
ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_std_init_image(
|
||||
sd15_std: StableDiffusion_1,
|
||||
cutecat_init: Image.Image,
|
||||
|
@ -643,7 +643,7 @@ def test_diffusion_std_init_image(
|
|||
ensure_similar_images(predicted_image, expected_image_std_init_image)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_rectangular_init_latents(
|
||||
sd15_std: StableDiffusion_1,
|
||||
cutecat_init: Image.Image,
|
||||
|
@ -658,7 +658,7 @@ def test_rectangular_init_latents(
|
|||
assert sd15.lda.decode_latents(x).size == (width, height)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_inpainting(
|
||||
sd15_inpainting: StableDiffusion_1_Inpainting,
|
||||
kitchen_dog: Image.Image,
|
||||
|
@ -692,7 +692,7 @@ def test_diffusion_inpainting(
|
|||
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_inpainting_float16(
|
||||
sd15_inpainting_float16: StableDiffusion_1_Inpainting,
|
||||
kitchen_dog: Image.Image,
|
||||
|
@ -727,7 +727,7 @@ def test_diffusion_inpainting_float16(
|
|||
ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_controlnet(
|
||||
sd15_std: StableDiffusion_1,
|
||||
controlnet_data: tuple[str, Image.Image, Image.Image, Path],
|
||||
|
@ -770,7 +770,7 @@ def test_diffusion_controlnet(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_controlnet_structural_copy(
|
||||
sd15_std: StableDiffusion_1,
|
||||
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
||||
|
@ -814,7 +814,7 @@ def test_diffusion_controlnet_structural_copy(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_controlnet_float16(
|
||||
sd15_std_float16: StableDiffusion_1,
|
||||
controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
||||
|
@ -857,7 +857,7 @@ def test_diffusion_controlnet_float16(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_controlnet_stack(
|
||||
sd15_std: StableDiffusion_1,
|
||||
controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
|
||||
|
@ -912,7 +912,7 @@ def test_diffusion_controlnet_stack(
|
|||
ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_lora(
|
||||
sd15_std: StableDiffusion_1,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
|
@ -949,7 +949,7 @@ def test_diffusion_lora(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_lora_float16(
|
||||
sd15_std_float16: StableDiffusion_1,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
|
@ -986,7 +986,7 @@ def test_diffusion_lora_float16(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_lora_twice(
|
||||
sd15_std: StableDiffusion_1,
|
||||
lora_data_pokemon: tuple[Image.Image, Path],
|
||||
|
@ -1025,7 +1025,7 @@ def test_diffusion_lora_twice(
|
|||
ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_refonly(
|
||||
sd15_ddim: StableDiffusion_1,
|
||||
condition_image_refonly: Image.Image,
|
||||
|
@ -1061,7 +1061,7 @@ def test_diffusion_refonly(
|
|||
ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_inpainting_refonly(
|
||||
sd15_inpainting: StableDiffusion_1_Inpainting,
|
||||
scene_image_inpainting_refonly: Image.Image,
|
||||
|
@ -1106,7 +1106,7 @@ def test_diffusion_inpainting_refonly(
|
|||
ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_textual_inversion_random_init(
|
||||
sd15_std: StableDiffusion_1,
|
||||
expected_image_textual_inversion_random_init: Image.Image,
|
||||
|
@ -1141,7 +1141,7 @@ def test_diffusion_textual_inversion_random_init(
|
|||
ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_ip_adapter(
|
||||
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
||||
ip_adapter_weights: Path,
|
||||
|
@ -1196,7 +1196,7 @@ def test_diffusion_ip_adapter(
|
|||
ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_sdxl_ip_adapter(
|
||||
sdxl_ddim: StableDiffusion_XL,
|
||||
sdxl_ip_adapter_weights: Path,
|
||||
|
@ -1215,7 +1215,7 @@ def test_diffusion_sdxl_ip_adapter(
|
|||
ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
|
||||
ip_adapter.inject()
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
|
||||
text=prompt, negative_text=negative_prompt
|
||||
)
|
||||
|
@ -1236,7 +1236,7 @@ def test_diffusion_sdxl_ip_adapter(
|
|||
manual_seed(2)
|
||||
x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
for step in sdxl.steps:
|
||||
x = sdxl(
|
||||
x,
|
||||
|
@ -1254,7 +1254,7 @@ def test_diffusion_sdxl_ip_adapter(
|
|||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_ip_adapter_controlnet(
|
||||
sd15_ddim: StableDiffusion_1,
|
||||
ip_adapter_weights: Path,
|
||||
|
@ -1320,7 +1320,7 @@ def test_diffusion_ip_adapter_controlnet(
|
|||
ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_ip_adapter_plus(
|
||||
sd15_ddim_lda_ft_mse: StableDiffusion_1,
|
||||
ip_adapter_plus_weights: Path,
|
||||
|
@ -1371,7 +1371,7 @@ def test_diffusion_ip_adapter_plus(
|
|||
ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_diffusion_sdxl_ip_adapter_plus(
|
||||
sdxl_ddim: StableDiffusion_XL,
|
||||
sdxl_ip_adapter_plus_weights: Path,
|
||||
|
@ -1427,7 +1427,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
|
|||
ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_sdxl_random_init(
|
||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
|
||||
) -> None:
|
||||
|
@ -1462,7 +1462,7 @@ def test_sdxl_random_init(
|
|||
ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_sdxl_random_init_sag(
|
||||
sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
|
||||
) -> None:
|
||||
|
@ -1498,7 +1498,7 @@ def test_sdxl_random_init_sag(
|
|||
ensure_similar_images(img_1=predicted_image, img_2=expected_image)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
|
||||
manual_seed(seed=2)
|
||||
sd = sd15_ddim
|
||||
|
@ -1529,7 +1529,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
|
|||
ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_t2i_adapter_depth(
|
||||
sd15_std: StableDiffusion_1,
|
||||
t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path],
|
||||
|
@ -1570,7 +1570,7 @@ def test_t2i_adapter_depth(
|
|||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_t2i_adapter_xl_canny(
|
||||
sdxl_ddim: StableDiffusion_XL,
|
||||
t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
|
||||
|
@ -1619,7 +1619,7 @@ def test_t2i_adapter_xl_canny(
|
|||
ensure_similar_images(predicted_image, expected_image)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_restart(
|
||||
sd15_ddim: StableDiffusion_1,
|
||||
expected_restart: Image.Image,
|
||||
|
@ -1659,7 +1659,7 @@ def test_restart(
|
|||
ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_freeu(
|
||||
sd15_std: StableDiffusion_1,
|
||||
expected_freeu: Image.Image,
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
||||
from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
|
||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
|
@ -41,7 +41,7 @@ def informative_drawings_model(informative_drawings_weights: Path, test_device:
|
|||
return model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_preprocessor_informative_drawing(
|
||||
informative_drawings_model: InformativeDrawings,
|
||||
cutecat_init: Image.Image,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Any, Callable
|
||||
from warnings import warn
|
||||
|
||||
import pytest
|
||||
|
@ -60,8 +61,9 @@ def test_converter_multiple_tensors(test_device: torch.device) -> None:
|
|||
|
||||
|
||||
def test_converter_no_parent_device_or_dtype() -> None:
|
||||
identity: Callable[[Any], Any] = lambda x: x
|
||||
chain = fl.Chain(
|
||||
fl.Lambda(func=(lambda x: x)),
|
||||
fl.Lambda(func=identity),
|
||||
fl.Converter(set_device=True, set_dtype=False),
|
||||
)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from PIL import Image
|
|||
from torch import device as Device, dtype as DType
|
||||
from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur # type: ignore
|
||||
|
||||
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image
|
||||
from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -62,3 +62,18 @@ def test_tensor_to_image() -> None:
|
|||
assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
|
||||
assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
|
||||
assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
|
||||
|
||||
|
||||
def test_no_grad() -> None:
|
||||
x = torch.randn(1, 1, requires_grad=True)
|
||||
|
||||
with torch.no_grad():
|
||||
y = x + 1
|
||||
assert not y.requires_grad
|
||||
|
||||
with no_grad():
|
||||
z = x + 1
|
||||
assert not z.requires_grad
|
||||
|
||||
w = x + 1
|
||||
assert w.requires_grad
|
||||
|
|
|
@ -7,7 +7,7 @@ import transformers # type: ignore
|
|||
from diffusers import StableDiffusionPipeline # type: ignore
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import load_from_safetensors
|
||||
from refiners.fluxion.utils import load_from_safetensors, no_grad
|
||||
from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
|
@ -124,7 +124,7 @@ def test_encoder(
|
|||
our_tokens = tokenizer(prompt)
|
||||
assert torch.equal(our_tokens, ref_tokens)
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0]
|
||||
our_embeddings = our_encoder_with_new_concepts(prompt)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
import torch
|
||||
from transformers import CLIPVisionModelWithProjection # type: ignore
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors
|
||||
from refiners.fluxion.utils import load_from_safetensors, no_grad
|
||||
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
|
||||
|
||||
|
||||
|
@ -44,7 +44,7 @@ def test_encoder(
|
|||
):
|
||||
x = torch.randn(1, 3, 224, 224).to(test_device)
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
ref_embeddings = ref_encoder(x).image_embeds
|
||||
our_embeddings = our_encoder(x)
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
import torch
|
||||
import transformers # type: ignore
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors
|
||||
from refiners.fluxion.utils import load_from_safetensors, no_grad
|
||||
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
|
||||
from refiners.foundationals.clip.tokenizer import CLIPTokenizer
|
||||
|
||||
|
@ -89,7 +89,7 @@ def test_encoder(
|
|||
our_tokens = tokenizer(prompt)
|
||||
assert torch.equal(our_tokens, ref_tokens)
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
|
||||
our_embeddings = our_encoder(prompt)
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from transformers import AutoModel # type: ignore
|
||||
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed
|
||||
from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
|
||||
from refiners.foundationals.dinov2 import (
|
||||
DINOv2_base,
|
||||
DINOv2_base_reg,
|
||||
|
@ -124,7 +124,7 @@ def test_encoder(
|
|||
|
||||
x = torch.randn(1, 3, 518, 518).to(test_device)
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
ref_features = ref_backbone(x).last_hidden_state
|
||||
our_features = our_backbone(x)
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
|||
from PIL import Image
|
||||
from tests.utils import ensure_similar_images
|
||||
|
||||
from refiners.fluxion.utils import load_from_safetensors
|
||||
from refiners.fluxion.utils import load_from_safetensors, no_grad
|
||||
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
|
||||
|
||||
|
||||
|
@ -38,7 +38,7 @@ def sample_image(ref_path: Path) -> Image.Image:
|
|||
return img
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
|
||||
encoded = encoder.encode_image(sample_image)
|
||||
decoded = encoder.decode_latents(encoded)
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import lookup_top_adapter
|
||||
from refiners.fluxion.utils import no_grad
|
||||
from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
|
||||
|
||||
|
@ -18,7 +18,7 @@ def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
|
|||
yield unet
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_single_controlnet(unet: SD1UNet) -> None:
|
||||
original_parent = unet.parent
|
||||
cn = SD1ControlnetAdapter(unet, name="cn")
|
||||
|
@ -43,7 +43,7 @@ def test_single_controlnet(unet: SD1UNet) -> None:
|
|||
assert len(list(unet.walk(Controlnet))) == 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
|
||||
original_parent = unet.parent
|
||||
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
|
||||
|
@ -71,7 +71,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
|
|||
assert len(list(unet.walk(Controlnet))) == 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
|
||||
original_parent = unet.parent
|
||||
cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
|
||||
|
@ -86,7 +86,7 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
|
|||
assert len(list(unet.walk(Controlnet))) == 0
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_two_controlnets_same_name(unet: SD1UNet) -> None:
|
||||
SD1ControlnetAdapter(unet, name="cnx").inject()
|
||||
cn2 = SD1ControlnetAdapter(unet, name="cnx")
|
||||
|
|
|
@ -4,6 +4,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.fluxion.utils import no_grad
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
|
||||
from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
|
||||
|
||||
|
@ -52,14 +53,14 @@ def test_freeu_identity_scales() -> None:
|
|||
unet = SD1UNet(in_channels=4)
|
||||
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
unet.set_timestep(timestep=timestep)
|
||||
y_1 = unet(x.clone())
|
||||
|
||||
freeu = SDFreeUAdapter(unet, backbone_scales=[1.0, 1.0], skip_scales=[1.0, 1.0])
|
||||
freeu.inject()
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
unet.set_timestep(timestep=timestep)
|
||||
y_2 = unet(x.clone())
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from refiners.fluxion.utils import no_grad
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
|
||||
from refiners.foundationals.latent_diffusion.reference_only_control import (
|
||||
|
@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.reference_only_control import (
|
|||
)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_refonly_inject_eject() -> None:
|
||||
unet = SD1UNet(in_channels=9)
|
||||
adapter = ReferenceOnlyControlAdapter(unet)
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
from torch import Tensor
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
from refiners.fluxion.utils import manual_seed, no_grad
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
|
||||
|
||||
|
||||
|
@ -65,7 +65,7 @@ def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
|
|||
return double_text_encoder
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
|
||||
manual_seed(seed=0)
|
||||
prompt = "A photo of a pizza."
|
||||
|
|
|
@ -6,7 +6,7 @@ import pytest
|
|||
import torch
|
||||
|
||||
from refiners.fluxion.model_converter import ConversionStage, ModelConverter
|
||||
from refiners.fluxion.utils import manual_seed
|
||||
from refiners.fluxion.utils import manual_seed, no_grad
|
||||
from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
|
||||
|
||||
|
||||
|
@ -37,7 +37,7 @@ def refiners_sdxl_unet() -> SDXLUNet:
|
|||
return unet
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
|
||||
source = diffusers_sdxl_unet
|
||||
target = refiners_sdxl_unet
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.fluxion.utils import no_grad
|
||||
from refiners.foundationals.latent_diffusion import SD1UNet
|
||||
|
||||
|
||||
|
@ -13,11 +14,11 @@ def test_unet_context_flush():
|
|||
unet = SD1UNet(in_channels=4)
|
||||
unet.set_clip_text_embedding(clip_text_embedding=text_embedding) # not flushed between forward-s
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
unet.set_timestep(timestep=timestep)
|
||||
y_1 = unet(x.clone())
|
||||
|
||||
with torch.no_grad():
|
||||
with no_grad():
|
||||
unet.set_timestep(timestep=timestep)
|
||||
y_2 = unet(x.clone())
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from torch import Tensor
|
|||
|
||||
from refiners.fluxion import manual_seed
|
||||
from refiners.fluxion.model_converter import ModelConverter
|
||||
from refiners.fluxion.utils import image_to_tensor
|
||||
from refiners.fluxion.utils import image_to_tensor, no_grad
|
||||
from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
|
||||
from refiners.foundationals.segment_anything.model import SegmentAnythingH
|
||||
from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
|
||||
|
@ -98,7 +98,7 @@ def truck(ref_path: Path) -> Image.Image:
|
|||
return Image.open(ref_path / "truck.jpg").convert("RGB")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
|
||||
manual_seed(seed=0)
|
||||
x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
|
||||
|
@ -124,7 +124,7 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
|
|||
assert torch.equal(input=y_1, other=y_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
|
||||
image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
|
||||
y_1 = facebook_sam_h.image_encoder(image_tensor)
|
||||
|
@ -133,7 +133,7 @@ def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, tru
|
|||
assert torch.allclose(input=y_1, other=y_2, atol=1e-4)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
|
||||
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
|
||||
refiners_prompt_encoder = sam_h.point_encoder
|
||||
|
@ -144,7 +144,7 @@ def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM,
|
|||
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
|
||||
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
|
||||
refiners_prompt_encoder = sam_h.mask_encoder
|
||||
|
@ -155,7 +155,7 @@ def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam
|
|||
assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
|
||||
facebook_prompt_encoder = facebook_sam_h.prompt_encoder
|
||||
refiners_prompt_encoder = sam_h.point_encoder
|
||||
|
@ -174,7 +174,7 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro
|
|||
assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
|
||||
dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
|
||||
dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
|
||||
|
@ -223,7 +223,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
|
|||
assert torch.equal(input=y_1, other=y_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@no_grad()
|
||||
def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
|
||||
manual_seed(seed=0)
|
||||
facebook_mask_decoder = facebook_sam_h.mask_decoder
|
||||
|
|
Loading…
Reference in a new issue