Mirror of https://github.com/finegrain-ai/refiners.git (synced 2024-11-22 06:08:46 +00:00)
upgrade pyright to 1.1.342 ; improve no_grad typing
This commit is contained in:
  parent 7b14b4d981
  commit 20c229903f
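In short: the commit pins pyright to 1.1.342 in pyproject.toml, replaces direct uses of torch.no_grad with a small no_grad wrapper exported from refiners.fluxion.utils (usable both as a decorator and as a context manager), binds a few inline lambdas to explicitly typed Callable names, and adds type annotations to two previously untyped dicts.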
@@ -92,7 +92,7 @@ from PIL import Image
 
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import StableDiffusion_XL
 from refiners.foundationals.latent_diffusion import SDXLIPAdapter, SDXLT2IAdapter
-from refiners.fluxion.utils import manual_seed, image_to_tensor, load_from_safetensors
+from refiners.fluxion.utils import manual_seed, no_grad, image_to_tensor, load_from_safetensors
 
 # Load inputs
 init_image = Image.open("dropy_logo.png")
@@ -122,7 +122,7 @@ t2i_adapter.set_scale(0.8)
 sdxl.set_num_inference_steps(50)
 sdxl.set_self_attention_guidance(enable=True, scale=0.75)
 
-with torch.no_grad():
+with no_grad():
     # Note: default text prompts for IP-Adapter
     clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
         text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
@@ -54,7 +54,7 @@ build-backend = "hatchling.build"
 [tool.rye]
 managed = true
 dev-dependencies = [
-    "pyright == 1.1.333",
+    "pyright == 1.1.342",
     "ruff>=0.0.292",
    "docformatter>=1.7.5",
    "pytest>=7.4.2",
@@ -7,7 +7,7 @@ from diffusers import ControlNetModel  # type: ignore
 from torch import nn
 
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import no_grad, save_to_safetensors
 from refiners.foundationals.latent_diffusion import (
     DPMSolver,
     SD1ControlnetAdapter,
@@ -20,7 +20,7 @@ class Args(argparse.Namespace):
     output_path: str | None
 
 
-@torch.no_grad()
+@no_grad()
 def convert(args: Args) -> dict[str, torch.Tensor]:
     # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
     controlnet_src: nn.Module = ControlNetModel.from_pretrained(  # type: ignore
@@ -11,7 +11,7 @@ from torch.nn.init import zeros_
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.lora import Lora, LoraAdapter
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import save_to_safetensors
+from refiners.fluxion.utils import no_grad, save_to_safetensors
 from refiners.foundationals.latent_diffusion import SD1UNet
 from refiners.foundationals.latent_diffusion.lora import LoraTarget, lora_targets
 
@@ -37,7 +37,7 @@ class Args(argparse.Namespace):
     verbose: bool
 
 
-@torch.no_grad()
+@no_grad()
 def process(args: Args) -> None:
     diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu"))  # type: ignore
     # low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
@@ -1,3 +1,5 @@
+from typing import Callable
+
 from torch import Size, Tensor, device as Device, dtype as DType
 from torch.nn.functional import pad
 
@@ -40,7 +42,8 @@ class Downsample(Chain):
             ),
         )
         if padding == 0:
-            self.insert(0, Lambda(lambda x: pad(x, (0, 1, 0, 1))))
+            zero_pad: Callable[[Tensor], Tensor] = lambda x: pad(x, (0, 1, 0, 1))
+            self.insert(0, Lambda(zero_pad))
         if register_shape:
             self.insert(0, SetContext(context="sampling", key="shapes", callback=self.register_shape))
 
@@ -7,7 +7,7 @@ import torch
 from torch import Tensor, nn
 from torch.utils.hooks import RemovableHandle
 
-from refiners.fluxion.utils import norm, save_to_safetensors
+from refiners.fluxion.utils import no_grad, norm, save_to_safetensors
 
 TORCH_BASIC_LAYERS: list[type[nn.Module]] = [
     nn.Conv1d,
@@ -512,7 +512,7 @@ class ModelConverter:
 
         return True
 
-    @torch.no_grad()
+    @no_grad()
     def _trace_module_execution_order(
         self,
         module: nn.Module,
@@ -603,7 +603,7 @@ class ModelConverter:
 
         return converted_state_dict
 
-    @torch.no_grad()
+    @no_grad()
    def _collect_layers_outputs(
        self, module: nn.Module, args: ModuleArgs, keys_to_skip: list[str]
    ) -> list[tuple[str, Tensor]]:
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Iterable, Literal, TypeVar
+from typing import Any, Iterable, Literal, TypeVar
 
 import torch
 from jaxtyping import Float
@@ -7,7 +7,14 @@ from numpy import array, float32
 from PIL import Image
 from safetensors import safe_open as _safe_open  # type: ignore
 from safetensors.torch import save_file as _save_file  # type: ignore
-from torch import Tensor, device as Device, dtype as DType, manual_seed as _manual_seed, norm as _norm  # type: ignore
+from torch import (
+    Tensor,
+    device as Device,
+    dtype as DType,
+    manual_seed as _manual_seed,  # type: ignore
+    no_grad as _no_grad,  # type: ignore
+    norm as _norm,  # type: ignore
+)
 from torch.nn.functional import conv2d, interpolate as _interpolate, pad as _pad  # type: ignore
 
 T = TypeVar("T")
@@ -22,6 +29,11 @@ def manual_seed(seed: int) -> None:
     _manual_seed(seed)
 
 
+class no_grad(_no_grad):
+    def __new__(cls, orig_func: Any | None = None) -> "no_grad":  # type: ignore
+        return object.__new__(cls)
+
+
 def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant") -> Tensor:
     return _pad(input=x, pad=pad, value=value, mode=mode)  # type: ignore
 
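Note on the no_grad helper added above (context, not part of the diff): torch.no_grad already works both as a context manager and as a decorator, and the subclass overrides __new__ presumably so that pyright 1.1.342 accepts both forms while behaviour stays identical to torch.no_grad; the commit message only says "improve no_grad typing". A minimal usage sketch, assuming just the import path shown in this diff (the helper function name is made up):

    import torch

    from refiners.fluxion.utils import no_grad


    @no_grad()  # decorator form, as used throughout this commit
    def make_noise() -> torch.Tensor:  # hypothetical helper, for illustration only
        return torch.randn(1, 4, 64, 64)


    with no_grad():  # context-manager form, equivalent to torch.no_grad()
        y = make_noise() + 1

    assert not y.requires_grad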
@@ -1,4 +1,6 @@
-from torch import device as Device, dtype as DType
+from typing import Callable
+
+from torch import Tensor, device as Device, dtype as DType
 
 import refiners.fluxion.layers as fl
 from refiners.foundationals.clip.common import FeedForward, PositionalEncoder
@@ -126,6 +128,7 @@ class CLIPImageEncoder(fl.Chain):
         self.num_layers = num_layers
         self.num_attention_heads = num_attention_heads
         self.feedforward_dim = feedforward_dim
+        cls_token_pooling: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
         super().__init__(
             ViTEmbeddings(
                 image_size=image_size, embedding_dim=embedding_dim, patch_size=patch_size, device=device, dtype=dtype
@@ -142,7 +145,7 @@ class CLIPImageEncoder(fl.Chain):
                 )
                 for _ in range(num_layers)
             ),
-            fl.Lambda(func=lambda x: x[:, 0, :]),
+            fl.Lambda(func=cls_token_pooling),
             fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
             fl.Linear(in_features=embedding_dim, out_features=output_dim, bias=False, device=device, dtype=dtype),
         )
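Side note on the recurring lambda rewrites in this commit (here and in the FreeU, self-attention and test-converter hunks below): each inline lambda is first bound to a name with an explicit Callable annotation, and the named callable is passed to fl.Lambda instead. Presumably this gives the newer pyright a concrete signature to check; the commit message does not spell out the rule. An illustrative sketch of the pattern (the variable name is made up):

    from typing import Callable

    from torch import Tensor

    # Inline, the lambda's parameter and return types are opaque to the type checker:
    #     fl.Lambda(func=lambda x: x[:, 0, :])
    # An annotated binding spells the signature out:
    take_cls_token: Callable[[Tensor], Tensor] = lambda x: x[:, 0, :]
    #     fl.Lambda(func=take_cls_token)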
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Generic, TypeVar
+from typing import Any, Callable, Generic, TypeVar
 
 import torch
 from torch import Tensor
@@ -54,9 +54,10 @@ class FreeUBackboneFeatures(fl.Module):
 
 class FreeUSkipFeatures(fl.Chain):
     def __init__(self, n: int, skip_scale: float) -> None:
+        apply_filter: Callable[[Tensor], Tensor] = lambda x: fourier_filter(x, scale=skip_scale)
         super().__init__(
             fl.UseContext(context="unet", key="residuals").compose(lambda residuals: residuals[n]),
-            fl.Lambda(lambda x: fourier_filter(x, scale=skip_scale)),
+            fl.Lambda(apply_filter),
         )
 
 
@@ -122,7 +122,7 @@ class SD1LoraAdapter(fl.Chain, Adapter[StableDiffusion_1]):
         assert metadata is not None, "Invalid safetensors checkpoint: missing metadata"
         tensors = load_from_safetensors(checkpoint_path, device=target.device)
 
-        sub_targets = {}
+        sub_targets: dict[str, list[LoraTarget]] = {}
         for model_name in MODELS:
             if not (v := metadata.get(f"{model_name}_targets", "")):
                 continue
@@ -1,3 +1,5 @@
+from typing import Callable
+
 from torch import Tensor
 
 from refiners.fluxion.adapters.adapter import Adapter
@@ -45,8 +47,9 @@ class SelfAttentionInjectionAdapter(Chain, Adapter[SelfAttention]):
         )
 
         with self.setup_adapter(target):
+            slice_tensor: Callable[[Tensor], Tensor] = lambda x: x[:1]
             super().__init__(
-                Parallel(sa_guided, Chain(Lambda(lambda x: x[:1]), target)),
+                Parallel(sa_guided, Chain(Lambda(slice_tensor), target)),
                 Lambda(self.compute_averaged_unconditioned_x),
             )
 
@@ -7,7 +7,7 @@ from PIL import Image
 from torch import Tensor, device as Device, dtype as DType
 
 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import image_to_tensor, interpolate, normalize, pad
+from refiners.fluxion.utils import image_to_tensor, interpolate, no_grad, normalize, pad
 from refiners.foundationals.segment_anything.image_encoder import SAMViT, SAMViTH
 from refiners.foundationals.segment_anything.mask_decoder import MaskDecoder
 from refiners.foundationals.segment_anything.prompt_encoder import MaskEncoder, PointEncoder
@@ -39,7 +39,7 @@ class SegmentAnything(fl.Module):
         self.mask_encoder = mask_encoder.to(device=self.device, dtype=self.dtype)
         self.mask_decoder = mask_decoder.to(device=self.device, dtype=self.dtype)
 
-    @torch.no_grad()
+    @no_grad()
     def compute_image_embedding(self, image: Image.Image) -> ImageEmbedding:
         original_size = (image.height, image.width)
         target_size = self.compute_target_size(original_size)
@@ -48,7 +48,7 @@ class SegmentAnything(fl.Module):
             original_image_size=original_size,
         )
 
-    @torch.no_grad()
+    @no_grad()
     def predict(
         self,
         input: Image.Image | ImageEmbedding,
@@ -250,7 +250,7 @@ class MonitorTimestepLoss(Callback[LatentDiffusionTrainer[Any]]):
         self.timestep_bins[bin_index].append(loss_value)
 
     def on_epoch_end(self, trainer: LatentDiffusionTrainer[Any]) -> None:
-        log_data = {}
+        log_data: dict[str, WandbLoggable] = {}
         for bin_index, losses in self.timestep_bins.items():
             if losses:
                 avg_loss = sum(losses) / len(losses)
@@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, Iterable, TypeVar, cast
 
 import numpy as np
 from loguru import logger
-from torch import Tensor, cuda, device as Device, get_rng_state, no_grad, set_rng_state, stack
+from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
 from torch.autograd import backward
 from torch.nn import Parameter
 from torch.optim import Optimizer
@@ -26,7 +26,7 @@ from torch.optim.lr_scheduler import (
 from torch.utils.data import DataLoader, Dataset
 
 from refiners.fluxion import layers as fl
-from refiners.fluxion.utils import manual_seed
+from refiners.fluxion.utils import manual_seed, no_grad
 from refiners.training_utils.callback import (
     Callback,
     ClockCallback,
@@ -6,7 +6,7 @@ import pytest
 import torch
 from PIL import Image
 
-from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed
+from refiners.fluxion.utils import image_to_tensor, load_from_safetensors, manual_seed, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender
 from refiners.foundationals.latent_diffusion import (
     SD1ControlnetAdapter,
@@ -501,7 +501,7 @@ def sdxl_ddim(
     return sdxl
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_std_random_init(
     sd15_std: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
 ):
@@ -529,7 +529,7 @@ def test_diffusion_std_random_init(
     ensure_similar_images(predicted_image, expected_image_std_random_init)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_karras_random_init(
     sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
 ):
@@ -554,7 +554,7 @@ def test_diffusion_karras_random_init(
     ensure_similar_images(predicted_image, expected_karras_random_init, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_std_random_init_float16(
     sd15_std_float16: StableDiffusion_1, expected_image_std_random_init: Image.Image, test_device: torch.device
 ):
@@ -583,7 +583,7 @@ def test_diffusion_std_random_init_float16(
     ensure_similar_images(predicted_image, expected_image_std_random_init, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_std_random_init_sag(
     sd15_std: StableDiffusion_1, expected_image_std_random_init_sag: Image.Image, test_device: torch.device
 ):
@@ -612,7 +612,7 @@ def test_diffusion_std_random_init_sag(
     ensure_similar_images(predicted_image, expected_image_std_random_init_sag)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_std_init_image(
     sd15_std: StableDiffusion_1,
     cutecat_init: Image.Image,
@@ -643,7 +643,7 @@ def test_diffusion_std_init_image(
     ensure_similar_images(predicted_image, expected_image_std_init_image)
 
 
-@torch.no_grad()
+@no_grad()
 def test_rectangular_init_latents(
     sd15_std: StableDiffusion_1,
     cutecat_init: Image.Image,
@@ -658,7 +658,7 @@ def test_rectangular_init_latents(
     assert sd15.lda.decode_latents(x).size == (width, height)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_inpainting(
     sd15_inpainting: StableDiffusion_1_Inpainting,
     kitchen_dog: Image.Image,
@@ -692,7 +692,7 @@ def test_diffusion_inpainting(
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=25, min_ssim=0.95)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_inpainting_float16(
     sd15_inpainting_float16: StableDiffusion_1_Inpainting,
     kitchen_dog: Image.Image,
@@ -727,7 +727,7 @@ def test_diffusion_inpainting_float16(
     ensure_similar_images(predicted_image, expected_image_std_inpainting, min_psnr=20, min_ssim=0.92)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_controlnet(
     sd15_std: StableDiffusion_1,
     controlnet_data: tuple[str, Image.Image, Image.Image, Path],
@@ -770,7 +770,7 @@ def test_diffusion_controlnet(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_controlnet_structural_copy(
     sd15_std: StableDiffusion_1,
     controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
@@ -814,7 +814,7 @@ def test_diffusion_controlnet_structural_copy(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_controlnet_float16(
     sd15_std_float16: StableDiffusion_1,
     controlnet_data_canny: tuple[str, Image.Image, Image.Image, Path],
@@ -857,7 +857,7 @@ def test_diffusion_controlnet_float16(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_controlnet_stack(
     sd15_std: StableDiffusion_1,
     controlnet_data_depth: tuple[str, Image.Image, Image.Image, Path],
@@ -912,7 +912,7 @@ def test_diffusion_controlnet_stack(
     ensure_similar_images(predicted_image, expected_image_controlnet_stack, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_lora(
     sd15_std: StableDiffusion_1,
     lora_data_pokemon: tuple[Image.Image, Path],
@@ -949,7 +949,7 @@ def test_diffusion_lora(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_lora_float16(
     sd15_std_float16: StableDiffusion_1,
     lora_data_pokemon: tuple[Image.Image, Path],
@@ -986,7 +986,7 @@ def test_diffusion_lora_float16(
     ensure_similar_images(predicted_image, expected_image, min_psnr=33, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_lora_twice(
     sd15_std: StableDiffusion_1,
     lora_data_pokemon: tuple[Image.Image, Path],
@@ -1025,7 +1025,7 @@ def test_diffusion_lora_twice(
     ensure_similar_images(predicted_image, expected_image, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_refonly(
     sd15_ddim: StableDiffusion_1,
     condition_image_refonly: Image.Image,
@@ -1061,7 +1061,7 @@ def test_diffusion_refonly(
     ensure_similar_images(predicted_image, expected_image_refonly, min_psnr=35, min_ssim=0.99)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_inpainting_refonly(
     sd15_inpainting: StableDiffusion_1_Inpainting,
     scene_image_inpainting_refonly: Image.Image,
@@ -1106,7 +1106,7 @@ def test_diffusion_inpainting_refonly(
     ensure_similar_images(predicted_image, expected_image_inpainting_refonly, min_psnr=35, min_ssim=0.99)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_textual_inversion_random_init(
     sd15_std: StableDiffusion_1,
     expected_image_textual_inversion_random_init: Image.Image,
@@ -1141,7 +1141,7 @@ def test_diffusion_textual_inversion_random_init(
     ensure_similar_images(predicted_image, expected_image_textual_inversion_random_init, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_ip_adapter(
     sd15_ddim_lda_ft_mse: StableDiffusion_1,
     ip_adapter_weights: Path,
@@ -1196,7 +1196,7 @@ def test_diffusion_ip_adapter(
     ensure_similar_images(predicted_image, expected_image_ip_adapter_woman)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_sdxl_ip_adapter(
     sdxl_ddim: StableDiffusion_XL,
     sdxl_ip_adapter_weights: Path,
@@ -1215,7 +1215,7 @@ def test_diffusion_sdxl_ip_adapter(
     ip_adapter.clip_image_encoder.load_from_safetensors(image_encoder_weights)
     ip_adapter.inject()
 
-    with torch.no_grad():
+    with no_grad():
         clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
             text=prompt, negative_text=negative_prompt
         )
@@ -1236,7 +1236,7 @@ def test_diffusion_sdxl_ip_adapter(
     manual_seed(2)
     x = torch.randn(1, 4, 128, 128, device=test_device, dtype=torch.float16)
 
-    with torch.no_grad():
+    with no_grad():
         for step in sdxl.steps:
             x = sdxl(
                 x,
@@ -1254,7 +1254,7 @@ def test_diffusion_sdxl_ip_adapter(
     ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_woman)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_ip_adapter_controlnet(
     sd15_ddim: StableDiffusion_1,
     ip_adapter_weights: Path,
@@ -1320,7 +1320,7 @@ def test_diffusion_ip_adapter_controlnet(
     ensure_similar_images(predicted_image, expected_image_ip_adapter_controlnet)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_ip_adapter_plus(
     sd15_ddim_lda_ft_mse: StableDiffusion_1,
     ip_adapter_plus_weights: Path,
@@ -1371,7 +1371,7 @@ def test_diffusion_ip_adapter_plus(
     ensure_similar_images(predicted_image, expected_image_ip_adapter_plus_statue, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_diffusion_sdxl_ip_adapter_plus(
     sdxl_ddim: StableDiffusion_XL,
     sdxl_ip_adapter_plus_weights: Path,
@@ -1427,7 +1427,7 @@ def test_diffusion_sdxl_ip_adapter_plus(
     ensure_similar_images(predicted_image, expected_image_sdxl_ip_adapter_plus_woman)
 
 
-@torch.no_grad()
+@no_grad()
 def test_sdxl_random_init(
     sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init: Image.Image, test_device: torch.device
 ) -> None:
@@ -1462,7 +1462,7 @@ def test_sdxl_random_init(
     ensure_similar_images(img_1=predicted_image, img_2=expected_image, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_sdxl_random_init_sag(
     sdxl_ddim: StableDiffusion_XL, expected_sdxl_ddim_random_init_sag: Image.Image, test_device: torch.device
 ) -> None:
@@ -1498,7 +1498,7 @@ def test_sdxl_random_init_sag(
     ensure_similar_images(img_1=predicted_image, img_2=expected_image)
 
 
-@torch.no_grad()
+@no_grad()
 def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion: Image.Image) -> None:
     manual_seed(seed=2)
     sd = sd15_ddim
@@ -1529,7 +1529,7 @@ def test_multi_diffusion(sd15_ddim: StableDiffusion_1, expected_multi_diffusion:
     ensure_similar_images(img_1=result, img_2=expected_multi_diffusion, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_t2i_adapter_depth(
     sd15_std: StableDiffusion_1,
     t2i_adapter_data_depth: tuple[str, Image.Image, Image.Image, Path],
@@ -1570,7 +1570,7 @@ def test_t2i_adapter_depth(
     ensure_similar_images(predicted_image, expected_image)
 
 
-@torch.no_grad()
+@no_grad()
 def test_t2i_adapter_xl_canny(
     sdxl_ddim: StableDiffusion_XL,
     t2i_adapter_xl_data_canny: tuple[str, Image.Image, Image.Image, Path],
@@ -1619,7 +1619,7 @@ def test_t2i_adapter_xl_canny(
     ensure_similar_images(predicted_image, expected_image)
 
 
-@torch.no_grad()
+@no_grad()
 def test_restart(
     sd15_ddim: StableDiffusion_1,
     expected_restart: Image.Image,
@@ -1659,7 +1659,7 @@ def test_restart(
     ensure_similar_images(predicted_image, expected_restart, min_psnr=35, min_ssim=0.98)
 
 
-@torch.no_grad()
+@no_grad()
 def test_freeu(
     sd15_std: StableDiffusion_1,
     expected_freeu: Image.Image,
@@ -5,7 +5,7 @@ import pytest
 import torch
 from PIL import Image
 
-from refiners.fluxion.utils import image_to_tensor, tensor_to_image
+from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
 from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
 from tests.utils import ensure_similar_images
 
@@ -41,7 +41,7 @@ def informative_drawings_model(informative_drawings_weights: Path, test_device:
     return model
 
 
-@torch.no_grad()
+@no_grad()
 def test_preprocessor_informative_drawing(
     informative_drawings_model: InformativeDrawings,
     cutecat_init: Image.Image,
@@ -1,3 +1,4 @@
+from typing import Any, Callable
 from warnings import warn
 
 import pytest
@@ -60,8 +61,9 @@ def test_converter_multiple_tensors(test_device: torch.device) -> None:
 
 
 def test_converter_no_parent_device_or_dtype() -> None:
+    identity: Callable[[Any], Any] = lambda x: x
     chain = fl.Chain(
-        fl.Lambda(func=(lambda x: x)),
+        fl.Lambda(func=identity),
         fl.Converter(set_device=True, set_dtype=False),
     )
 
@@ -7,7 +7,7 @@ from PIL import Image
 from torch import device as Device, dtype as DType
 from torchvision.transforms.functional import gaussian_blur as torch_gaussian_blur  # type: ignore
 
-from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, tensor_to_image
+from refiners.fluxion.utils import gaussian_blur, image_to_tensor, manual_seed, no_grad, tensor_to_image
 
 
 @dataclass
@@ -62,3 +62,18 @@ def test_tensor_to_image() -> None:
     assert tensor_to_image(torch.zeros(1, 3, 512, 512)).mode == "RGB"
     assert tensor_to_image(torch.zeros(1, 1, 512, 512)).mode == "L"
     assert tensor_to_image(torch.zeros(1, 4, 512, 512)).mode == "RGBA"
+
+
+def test_no_grad() -> None:
+    x = torch.randn(1, 1, requires_grad=True)
+
+    with torch.no_grad():
+        y = x + 1
+    assert not y.requires_grad
+
+    with no_grad():
+        z = x + 1
+    assert not z.requires_grad
+
+    w = x + 1
+    assert w.requires_grad
@@ -7,7 +7,7 @@ import transformers  # type: ignore
 from diffusers import StableDiffusionPipeline  # type: ignore
 
 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import load_from_safetensors
+from refiners.fluxion.utils import load_from_safetensors, no_grad
 from refiners.foundationals.clip.concepts import ConceptExtender, TokenExtender
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
@@ -124,7 +124,7 @@ def test_encoder(
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)
 
-    with torch.no_grad():
+    with no_grad():
         ref_embeddings = ref_encoder_with_new_concepts(ref_tokens.to(test_device))[0]
         our_embeddings = our_encoder_with_new_concepts(prompt)
 
@@ -5,7 +5,7 @@ import pytest
 import torch
 from transformers import CLIPVisionModelWithProjection  # type: ignore
 
-from refiners.fluxion.utils import load_from_safetensors
+from refiners.fluxion.utils import load_from_safetensors, no_grad
 from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH
 
 
@@ -44,7 +44,7 @@ def test_encoder(
 ):
     x = torch.randn(1, 3, 224, 224).to(test_device)
 
-    with torch.no_grad():
+    with no_grad():
         ref_embeddings = ref_encoder(x).image_embeds
         our_embeddings = our_encoder(x)
 
@@ -5,7 +5,7 @@ import pytest
 import torch
 import transformers  # type: ignore
 
-from refiners.fluxion.utils import load_from_safetensors
+from refiners.fluxion.utils import load_from_safetensors, no_grad
 from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
 from refiners.foundationals.clip.tokenizer import CLIPTokenizer
 
@@ -89,7 +89,7 @@ def test_encoder(
     our_tokens = tokenizer(prompt)
     assert torch.equal(our_tokens, ref_tokens)
 
-    with torch.no_grad():
+    with no_grad():
         ref_embeddings = ref_encoder(ref_tokens.to(test_device))[0]
         our_embeddings = our_encoder(prompt)
 
@@ -7,7 +7,7 @@ import torch
 from transformers import AutoModel  # type: ignore
 from transformers.models.dinov2.modeling_dinov2 import Dinov2Model  # type: ignore
 
-from refiners.fluxion.utils import load_from_safetensors, manual_seed
+from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
 from refiners.foundationals.dinov2 import (
     DINOv2_base,
     DINOv2_base_reg,
@@ -124,7 +124,7 @@ def test_encoder(
 
     x = torch.randn(1, 3, 518, 518).to(test_device)
 
-    with torch.no_grad():
+    with no_grad():
         ref_features = ref_backbone(x).last_hidden_state
         our_features = our_backbone(x)
 
@@ -6,7 +6,7 @@ import torch
 from PIL import Image
 from tests.utils import ensure_similar_images
 
-from refiners.fluxion.utils import load_from_safetensors
+from refiners.fluxion.utils import load_from_safetensors, no_grad
 from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
 
 
@@ -38,7 +38,7 @@ def sample_image(ref_path: Path) -> Image.Image:
     return img
 
 
-@torch.no_grad()
+@no_grad()
 def test_encode_decode(encoder: LatentDiffusionAutoencoder, sample_image: Image.Image):
     encoded = encoder.encode_image(sample_image)
     decoded = encoder.decode_latents(encoded)
@@ -1,10 +1,10 @@
 from typing import Iterator
 
 import pytest
-import torch
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import lookup_top_adapter
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.latent_diffusion import SD1ControlnetAdapter, SD1UNet
 from refiners.foundationals.latent_diffusion.stable_diffusion_1.controlnet import Controlnet
 
@@ -18,7 +18,7 @@ def unet(request: pytest.FixtureRequest) -> Iterator[SD1UNet]:
     yield unet
 
 
-@torch.no_grad()
+@no_grad()
 def test_single_controlnet(unet: SD1UNet) -> None:
     original_parent = unet.parent
     cn = SD1ControlnetAdapter(unet, name="cn")
@@ -43,7 +43,7 @@ def test_single_controlnet(unet: SD1UNet) -> None:
     assert len(list(unet.walk(Controlnet))) == 0
 
 
-@torch.no_grad()
+@no_grad()
 def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
     original_parent = unet.parent
     cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
@@ -71,7 +71,7 @@ def test_two_controlnets_eject_bottom_up(unet: SD1UNet) -> None:
     assert len(list(unet.walk(Controlnet))) == 0
 
 
-@torch.no_grad()
+@no_grad()
 def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
     original_parent = unet.parent
     cn1 = SD1ControlnetAdapter(unet, name="cn1").inject()
@@ -86,7 +86,7 @@ def test_two_controlnets_eject_top_down(unet: SD1UNet) -> None:
     assert len(list(unet.walk(Controlnet))) == 0
 
 
-@torch.no_grad()
+@no_grad()
 def test_two_controlnets_same_name(unet: SD1UNet) -> None:
     SD1ControlnetAdapter(unet, name="cnx").inject()
     cn2 = SD1ControlnetAdapter(unet, name="cnx")
@@ -4,6 +4,7 @@ import pytest
 import torch
 
 from refiners.fluxion import manual_seed
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.latent_diffusion import SD1UNet, SDXLUNet
 from refiners.foundationals.latent_diffusion.freeu import FreeUResidualConcatenator, SDFreeUAdapter
 
@@ -52,14 +53,14 @@ def test_freeu_identity_scales() -> None:
     unet = SD1UNet(in_channels=4)
     unet.set_clip_text_embedding(clip_text_embedding=text_embedding)  # not flushed between forward-s
 
-    with torch.no_grad():
+    with no_grad():
         unet.set_timestep(timestep=timestep)
         y_1 = unet(x.clone())
 
     freeu = SDFreeUAdapter(unet, backbone_scales=[1.0, 1.0], skip_scales=[1.0, 1.0])
     freeu.inject()
 
-    with torch.no_grad():
+    with no_grad():
         unet.set_timestep(timestep=timestep)
         y_2 = unet(x.clone())
 
@@ -1,6 +1,6 @@
 import pytest
-import torch
 
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.latent_diffusion import SD1UNet
 from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock
 from refiners.foundationals.latent_diffusion.reference_only_control import (
@@ -11,7 +11,7 @@ from refiners.foundationals.latent_diffusion.reference_only_control import (
 )
 
 
-@torch.no_grad()
+@no_grad()
 def test_refonly_inject_eject() -> None:
     unet = SD1UNet(in_channels=9)
     adapter = ReferenceOnlyControlAdapter(unet)
@@ -7,7 +7,7 @@ import torch
 from torch import Tensor
 
 import refiners.fluxion.layers as fl
-from refiners.fluxion.utils import manual_seed
+from refiners.fluxion.utils import manual_seed, no_grad
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl.text_encoder import DoubleTextEncoder
 
 
@@ -65,7 +65,7 @@ def double_text_encoder(double_text_encoder_weights: Path) -> DoubleTextEncoder:
     return double_text_encoder
 
 
-@torch.no_grad()
+@no_grad()
 def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder: DoubleTextEncoder) -> None:
     manual_seed(seed=0)
     prompt = "A photo of a pizza."
@@ -6,7 +6,7 @@ import pytest
 import torch
 
 from refiners.fluxion.model_converter import ConversionStage, ModelConverter
-from refiners.fluxion.utils import manual_seed
+from refiners.fluxion.utils import manual_seed, no_grad
 from refiners.foundationals.latent_diffusion.stable_diffusion_xl import SDXLUNet
 
 
@@ -37,7 +37,7 @@ def refiners_sdxl_unet() -> SDXLUNet:
     return unet
 
 
-@torch.no_grad()
+@no_grad()
 def test_sdxl_unet(diffusers_sdxl_unet: Any, refiners_sdxl_unet: SDXLUNet) -> None:
     source = diffusers_sdxl_unet
     target = refiners_sdxl_unet
@@ -1,6 +1,7 @@
 import torch
 
 from refiners.fluxion import manual_seed
+from refiners.fluxion.utils import no_grad
 from refiners.foundationals.latent_diffusion import SD1UNet
 
 
@@ -13,11 +14,11 @@ def test_unet_context_flush():
     unet = SD1UNet(in_channels=4)
     unet.set_clip_text_embedding(clip_text_embedding=text_embedding)  # not flushed between forward-s
 
-    with torch.no_grad():
+    with no_grad():
         unet.set_timestep(timestep=timestep)
         y_1 = unet(x.clone())
 
-    with torch.no_grad():
+    with no_grad():
         unet.set_timestep(timestep=timestep)
         y_2 = unet(x.clone())
 
@@ -18,7 +18,7 @@ from torch import Tensor
 
 from refiners.fluxion import manual_seed
 from refiners.fluxion.model_converter import ModelConverter
-from refiners.fluxion.utils import image_to_tensor
+from refiners.fluxion.utils import image_to_tensor, no_grad
 from refiners.foundationals.segment_anything.image_encoder import FusedSelfAttention
 from refiners.foundationals.segment_anything.model import SegmentAnythingH
 from refiners.foundationals.segment_anything.transformer import TwoWayTranformerLayer
@@ -98,7 +98,7 @@ def truck(ref_path: Path) -> Image.Image:
     return Image.open(ref_path / "truck.jpg").convert("RGB")
 
 
-@torch.no_grad()
+@no_grad()
 def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
     manual_seed(seed=0)
     x = torch.randn(25, 14, 14, 1280, device=facebook_sam_h.device)
@@ -124,7 +124,7 @@ def test_fused_self_attention(facebook_sam_h: FacebookSAM) -> None:
     assert torch.equal(input=y_1, other=y_2)
 
 
-@torch.no_grad()
+@no_grad()
 def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, truck: Image.Image) -> None:
     image_tensor = image_to_tensor(image=truck.resize(size=(1024, 1024)), device=facebook_sam_h.device)
     y_1 = facebook_sam_h.image_encoder(image_tensor)
@@ -133,7 +133,7 @@ def test_image_encoder(sam_h: SegmentAnythingH, facebook_sam_h: FacebookSAM, tru
     assert torch.allclose(input=y_1, other=y_2, atol=1e-4)
 
 
-@torch.no_grad()
+@no_grad()
 def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
     facebook_prompt_encoder = facebook_sam_h.prompt_encoder
     refiners_prompt_encoder = sam_h.point_encoder
@@ -144,7 +144,7 @@ def test_prompt_encoder_dense_positional_embedding(facebook_sam_h: FacebookSAM,
     assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
 
 
-@torch.no_grad()
+@no_grad()
 def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
     facebook_prompt_encoder = facebook_sam_h.prompt_encoder
     refiners_prompt_encoder = sam_h.mask_encoder
@@ -155,7 +155,7 @@ def test_prompt_encoder_no_mask_dense_embedding(facebook_sam_h: FacebookSAM, sam
     assert torch.equal(input=refiners_dense_pe, other=facebook_dense_pe)
 
 
-@torch.no_grad()
+@no_grad()
 def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, prompt: SAMPrompt) -> None:
     facebook_prompt_encoder = facebook_sam_h.prompt_encoder
     refiners_prompt_encoder = sam_h.point_encoder
@@ -174,7 +174,7 @@ def test_point_encoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH, pro
     assert torch.equal(input=refiners_sparse_pe, other=facebook_sparse_pe)
 
 
-@torch.no_grad()
+@no_grad()
 def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
     dense_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
     dense_positional_embedding = torch.randn(1, 64 * 64, 256, device=facebook_sam_h.device)
@@ -223,7 +223,7 @@ def test_two_way_transformer(facebook_sam_h: FacebookSAM) -> None:
     assert torch.equal(input=y_1, other=y_2)
 
 
-@torch.no_grad()
+@no_grad()
 def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> None:
     manual_seed(seed=0)
     facebook_mask_decoder = facebook_sam_h.mask_decoder